2017-02-22 23:26:44 +01:00
|
|
|
/**
 * Copyright (c) 2015-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD+Patents license found in the
 * LICENSE file in the root directory of this source tree.
 */
|
|
|
|
|
|
|
|
// Copyright 2004-present Facebook. All Rights Reserved.
|
2017-11-22 05:11:28 -08:00
|
|
|
#pragma once
|
|
|
|
|
2017-02-22 23:26:44 +01:00
|
|
|
#include "../BlockSelectKernel.cuh"
|
|
|
|
#include "../Limits.cuh"
|
|
|
|
|
|
|
|
// Declares, for one (TYPE, DIR, WARP_Q) instantiation, the pair of host-side
// launcher functions whose bodies are produced by BLOCK_SELECT_IMPL:
//   runBlockSelect_<TYPE>_<DIR>_<WARP_Q>_      -- selects from the keys in `in`
//   runBlockSelectPair_<TYPE>_<DIR>_<WARP_Q>_  -- selects from (inK, inV) pairs
// The names are built by token pasting, so TYPE / DIR / WARP_Q must each be a
// single identifier or literal token (e.g. float, true, 32).
// The expansion deliberately ends without ';' -- the use site supplies it.
#define BLOCK_SELECT_DECL(TYPE, DIR, WARP_Q)                                  \
  extern void runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _(        \
    Tensor<TYPE, 2, true>& in, Tensor<TYPE, 2, true>& outK,                   \
    Tensor<int, 2, true>& outV, bool dir, int k, cudaStream_t stream);        \
                                                                              \
  extern void runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _(    \
    Tensor<TYPE, 2, true>& inK, Tensor<int, 2, true>& inV,                    \
    Tensor<TYPE, 2, true>& outK, Tensor<int, 2, true>& outV,                  \
    bool dir, int k, cudaStream_t stream)
|
|
|
|
|
|
|
|
// Defines, for one (TYPE, DIR, WARP_Q, THREAD_Q) instantiation, the bodies of
// the two host-side launchers declared by BLOCK_SELECT_DECL.
//
// runBlockSelect_<TYPE>_<DIR>_<WARP_Q>_(in, outK, outV, dir, k, stream)
//   Launches blockSelect<TYPE, int, DIR, WARP_Q, THREAD_Q, 128> with one
//   thread block per row of `in` (grid = in.getSize(0)), 128 threads per
//   block, no dynamic shared memory, on `stream`. The kernel writes k keys
//   per row into outK and k int values per row into outV (presumably the
//   source column of each selected key -- confirm in BlockSelectKernel.cuh).
//
// runBlockSelectPair_<TYPE>_<DIR>_<WARP_Q>_(inK, inV, outK, outV, dir, k, stream)
//   Same launch layout, over (key, value) input pairs, via blockSelectPair.
//
// Preconditions, checked with FAISS_ASSERT in both launchers:
//   - outputs have one row per input row and exactly k columns;
//   - k <= WARP_Q (the compile-time queue capacity of this instantiation);
//   - the runtime `dir` equals the compile-time DIR of this instantiation.
// The launch is asynchronous on `stream`; CUDA_TEST_ERROR() is invoked right
// after each launch (presumably a launch-error check -- see its definition).
//
// Comments inside the macro body use /* */ because a // comment would absorb
// the trailing line-continuation backslash.
#define BLOCK_SELECT_IMPL(TYPE, DIR, WARP_Q, THREAD_Q)                        \
  void runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _(               \
    Tensor<TYPE, 2, true>& in,                                                \
    Tensor<TYPE, 2, true>& outK,                                              \
    Tensor<int, 2, true>& outV,                                               \
    bool dir,                                                                 \
    int k,                                                                    \
    cudaStream_t stream) {                                                    \
    /* One output row of exactly k entries per input row. */                  \
    FAISS_ASSERT(in.getSize(0) == outK.getSize(0));                           \
    FAISS_ASSERT(in.getSize(0) == outV.getSize(0));                           \
    FAISS_ASSERT(outK.getSize(1) == k);                                       \
    FAISS_ASSERT(outV.getSize(1) == k);                                       \
                                                                              \
    /* One thread block per input row. */                                     \
    auto grid = dim3(in.getSize(0));                                          \
                                                                              \
    constexpr int kBlockSelectNumThreads = 128;                               \
    auto block = dim3(kBlockSelectNumThreads);                                \
                                                                              \
    /* This instantiation is specialized for k <= WARP_Q and for DIR. */      \
    FAISS_ASSERT(k <= WARP_Q);                                                \
    FAISS_ASSERT(dir == DIR);                                                 \
                                                                              \
    /* Initial key / value handed to the kernel; presumably sentinels that */ \
    /* lose every comparison in direction DIR -- confirm in blockSelect.   */ \
    auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax();       \
    auto vInit = -1;                                                          \
                                                                              \
    blockSelect<TYPE, int, DIR, WARP_Q, THREAD_Q, kBlockSelectNumThreads>     \
      <<<grid, block, 0, stream>>>(in, outK, outV, kInit, vInit, k);          \
    CUDA_TEST_ERROR();                                                        \
  }                                                                           \
                                                                              \
  void runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _(           \
    Tensor<TYPE, 2, true>& inK,                                               \
    Tensor<int, 2, true>& inV,                                                \
    Tensor<TYPE, 2, true>& outK,                                              \
    Tensor<int, 2, true>& outV,                                               \
    bool dir,                                                                 \
    int k,                                                                    \
    cudaStream_t stream) {                                                    \
    /* Keys and values must be shaped identically on both sides. */           \
    FAISS_ASSERT(inK.isSameSize(inV));                                        \
    FAISS_ASSERT(outK.isSameSize(outV));                                      \
    /* As in runBlockSelect_*: one output row of exactly k entries per    */  \
    /* input row. (outV is covered by the isSameSize check above.)        */  \
    FAISS_ASSERT(inK.getSize(0) == outK.getSize(0));                          \
    FAISS_ASSERT(outK.getSize(1) == k);                                       \
                                                                              \
    /* One thread block per input row. */                                     \
    auto grid = dim3(inK.getSize(0));                                         \
                                                                              \
    constexpr int kBlockSelectNumThreads = 128;                               \
    auto block = dim3(kBlockSelectNumThreads);                                \
                                                                              \
    /* This instantiation is specialized for k <= WARP_Q and for DIR. */      \
    FAISS_ASSERT(k <= WARP_Q);                                                \
    FAISS_ASSERT(dir == DIR);                                                 \
                                                                              \
    /* Initial key / value handed to the kernel; presumably sentinels that */ \
    /* lose every comparison in direction DIR -- confirm in blockSelectPair. */ \
    auto kInit = dir ? Limits<TYPE>::getMin() : Limits<TYPE>::getMax();       \
    auto vInit = -1;                                                          \
                                                                              \
    blockSelectPair<TYPE, int, DIR, WARP_Q, THREAD_Q, kBlockSelectNumThreads> \
      <<<grid, block, 0, stream>>>(inK, inV, outK, outV, kInit, vInit, k);    \
    CUDA_TEST_ERROR();                                                        \
  }
|
|
|
|
|
2017-11-22 05:11:28 -08:00
|
|
|
|
2017-02-22 23:26:44 +01:00
|
|
|
// Expands to a call of the runBlockSelect_<TYPE>_<DIR>_<WARP_Q>_ launcher.
// Only TYPE, DIR and WARP_Q are macro parameters; the call arguments
// in, outK, outV, dir, k and stream are whatever names are in scope at the
// expansion site. No trailing ';' -- the use site supplies it.
#define BLOCK_SELECT_CALL(TYPE, DIR, WARP_Q) \
  runBlockSelect_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _(in, outK, outV, dir, k, stream)
|
2017-11-22 05:11:28 -08:00
|
|
|
|
|
|
|
// Expands to a call of the runBlockSelectPair_<TYPE>_<DIR>_<WARP_Q>_ launcher.
// Only TYPE, DIR and WARP_Q are macro parameters; the call arguments
// inK, inV, outK, outV, dir, k and stream are whatever names are in scope at
// the expansion site. No trailing ';' -- the use site supplies it.
#define BLOCK_SELECT_PAIR_CALL(TYPE, DIR, WARP_Q) \
  runBlockSelectPair_ ## TYPE ## _ ## DIR ## _ ## WARP_Q ## _(inK, inV, outK, outV, dir, k, stream)
|