Optimized SIMD interleaved IVF flat/SQ implementation (#1566)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/1566

This diff converts the IVF flat and IVFSQ code to use an interleaved-by-32 format, the same as was added to the IVFPQ GPU implementation in D24064745. It also implements SQ6 on the GPU.

Unlike the IVFPQ version, the new interleaved format is enabled by default for GpuIndexIVFFlat and GpuIndexIVFScalarQuantizer; the IVFPQ layout remains opt-in until I can develop optimized PQ kernels.

To extend the interleaved format to codes that are not a whole number of 8/16/32-bit words, arbitrary-bit codes are packed in groups of 32 vectors: each dimension of SQ6 for 32 vectors packs into (32 * 6) / 8 = 24 bytes, and SQ4 packs into (32 * 4) / 8 = 16 bytes.
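
As a concrete illustration of the size math (a sketch mirroring the getGpuVectorsEncodingSize_ logic added in this diff; the free-standing helper name is hypothetical):

```
#include <cstddef>

// Hypothetical helper mirroring the interleaved-layout size computation.
size_t interleavedListSizeBytes(int numVecs, int dim, int bitsPerCode) {
  int bytesPerDimBlock = bitsPerCode * 32 / 8;  // one dimension of 32 vectors
  int bytesPerBlock = bytesPerDimBlock * dim;   // all dimensions of 32 vectors
  int numBlocks = (numVecs + 31) / 32;          // lists are padded to blocks of 32
  return (size_t) bytesPerBlock * numBlocks;
}

// e.g. SQ6, dim = 128: 24 bytes per dimension block, 3072 bytes per 32 vectors
```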

The new IVF code also fuses the k-selection kernel with the distance computation, so results per (query, IVF list) are already k-selected. This improves speed, especially at small query batch sizes, where it is as much as 84% faster. The float32 version at the largest nq batch size (16384) is 13% faster, though it now appears to be running at the GPU's peak memory bandwidth and cannot go much faster as far as I can tell. There is still room for improvement with the sq8/6/4 versions, which are at about 50% of peak; I'll work on optimizing these in subsequent diffs.

Performance numbers (in milliseconds) for nlist = 1000, nb = 10^6, nprobe = 16, dim = 128 at varying nq are below. All versions are compared against the old non-interleaved implementation, except sq6, which is compared against the CPU implementation (it previously had no GPU implementation):

```
float32
nq 1 new 0.420811 old 0.774816 speedup 1.8412446442702306x
nq 4 new 0.377007 old 0.573527 speedup 1.5212635309158717x
nq 16 new 0.474821 old 0.611986 speedup 1.2888772821758094x
nq 64 new 0.926598 old 1.124938 speedup 1.2140518326178127x
nq 256 new 2.918364 old 3.339133 speedup 1.1441797527655906x
nq 1024 new 11.097743 old 12.647599 speedup 1.1396550631961833x
nq 4096 new 43.828697 old 50.088993 speedup 1.142835549046781x
nq 16384 new 173.582674 old 196.415956 speedup 1.1315412504821765x
sq8
nq 1 new 0.419673 old 0.660393 speedup 1.5735894374906176x
nq 4 new 0.396551 old 0.55526 speedup 1.4002234264949527x
nq 16 new 0.437477 old 0.589546 speedup 1.3476045597825714x
nq 64 new 0.697084 old 0.889233 speedup 1.2756468373969279x
nq 256 new 1.904308 old 2.231102 speedup 1.1716077441254251x
nq 1024 new 6.539976 old 8.23596 speedup 1.2593257222962286x
nq 4096 new 25.524117 old 31.276868 speedup 1.2253849173313223x
nq 16384 new 101.992982 old 125.355406 speedup 1.2290591327156215x
sq6
nq 1 new 0.693262 old 2.007591 speedup 2.895861881943623x
nq 4 new 0.62342 old 3.049899 speedup 4.892205896506368x
nq 16 new 0.626906 old 2.760067 speedup 4.402680784679043x
nq 64 new 1.002582 old 7.152971 speedup 7.134549592951x
nq 256 new 2.806507 old 19.4322 speedup 6.923980592245094x
nq 1024 new 9.414069 old 65.208767 speedup 6.926735612411593x
nq 4096 new 36.099553 old 249.866567 speedup 6.921597256342759x
nq 16384 new 142.230624 old 1040.07494 speedup 7.312594930329491x
sq4
nq 1 new 0.46687 old 0.670675 speedup 1.436534795553366x
nq 4 new 0.436246 old 0.589663 speedup 1.351675430834896x
nq 16 new 0.473243 old 0.628914 speedup 1.3289451719306993x
nq 64 new 0.789141 old 1.018548 speedup 1.2907047029618282x
nq 256 new 2.314464 old 2.592711 speedup 1.1202209237214318x
nq 1024 new 8.203663 old 9.574067 speedup 1.167047817541993x
nq 4096 new 31.910841 old 37.19758 speedup 1.1656721927197093x
nq 16384 new 126.195179 old 147.004414 speedup 1.164897226382951x
```

This diff also provides a new method for packing data of uniform, arbitrary bitwidth in parallel: a warp uses warp shuffles to exchange data so that each value reaches the lane responsible for bit-packing it, and unpacking works in a similar fashion. This allows coalesced memory loads and stores, instead of individual lanes having to read multiple bytes or words out of global or shared memory. This was the most difficult part of this particular diff.
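
A rough sketch of the packing idea for SQ6, assuming one 6-bit code per lane and a full, converged warp (illustrative only, not the actual WarpPackedBits implementation):

```
#include <cstdint>

// Sketch only: pack one 6-bit value per lane into six coalesced 32-bit words
// (24 bytes for 32 vectors of one dimension).
__device__ void warpPack6Bit(uint32_t val, uint32_t* out, int laneId) {
  uint32_t word = 0;
  // Every lane performs the same shuffles so the full-warp mask stays valid;
  // only lanes 0..5 accumulate bits that land in their output word.
  for (int src = 0; src < 32; ++src) {
    uint32_t v = __shfl_sync(0xffffffffu, val, src) & 0x3fu;
    int bit = src * 6 - laneId * 32;  // position of src's field in my word
    if (bit >= 0 && bit < 32) {
      word |= v << bit;               // field starts inside my word
    } else if (bit > -6 && bit < 0) {
      word |= v >> (-bit);            // field straddles in from the previous word
    }
  }
  if (laneId < 6) {
    out[laneId] = word;               // one coalesced 24-byte store per warp
  }
}
```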

The new IVF layout is completely transparent to the user. When copying to/from a CPU index, the codes are converted as needed. This functionality is implicitly tested in all of the CPU <-> GPU copy tests for the index types that currently exist.
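
For example, a minimal sketch of copying a CPU index to the GPU (assuming the usual GpuIndexIVFFlat constructor; the re-packing into the interleaved layout happens inside copyFrom and is invisible to the caller):

```
#include <faiss/IndexIVFFlat.h>
#include <faiss/gpu/GpuIndexIVFFlat.h>
#include <faiss/gpu/StandardGpuResources.h>

// Sketch: the codes are converted to the interleaved GPU layout inside copyFrom.
void copyToGpu(const faiss::IndexIVFFlat& cpuIndex) {
  faiss::gpu::StandardGpuResources res;
  faiss::gpu::GpuIndexIVFFlat gpuIndex(
      &res, cpuIndex.d, cpuIndex.nlist, cpuIndex.metric_type);
  gpuIndex.copyFrom(&cpuIndex);
}
```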

This diff also contains an optimization to the scalar quantizers: decoding now requires only an int-to-float conversion and a single multiply-add (rather than several operations), achieved by rewriting vmin and vdiff at runtime when the kernel starts.
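
In effect (see the Codec comments in the diff below), the constants are rewritten once so each decode becomes one FMA; a small illustrative sketch for 8-bit codes:

```
#include <cmath>
#include <cstdint>

// The decode used to be vmin + vdiff * ((v + 0.5f) / 255.0f) for 8-bit codes.
// Rewriting the constants once at kernel start:
//   vdiff' = vdiff / 255.0f
//   vmin'  = vmin + 0.5f * vdiff'
// reduces each decode to an int-to-float conversion plus one multiply-add.
inline float decodeSQ8(uint8_t v, float vminPrime, float vdiffPrime) {
  return std::fmaf((float) v, vdiffPrime, vminPrime);
}
```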

The old IVF flat code is still in the tree and is accessible by setting `interleavedLayout` to false in the config object. This will be deleted in a later diff as part of a cleanup when I am finally done with performance comparisons.
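
For instance (a minimal sketch; the dim/nlist values are just the ones used in the benchmark above):

```
#include <faiss/gpu/GpuIndexIVFFlat.h>
#include <faiss/gpu/StandardGpuResources.h>

int main() {
  faiss::gpu::StandardGpuResources res;
  faiss::gpu::GpuIndexIVFFlatConfig config;
  config.interleavedLayout = false;  // opt back into the old IVF flat code path
  faiss::gpu::GpuIndexIVFFlat index(&res, 128, 1000, faiss::METRIC_L2, config);
  return 0;
}
```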

The diff also contains various other changes:

- new code to allow copying a Tensor to/from a std::vector, which reduces the amount of boilerplate required in some places
- fixes a bug where miscellaneous index API calls were not properly stream-synchronized when the user was using a non-default stream (e.g., a PyTorch-provided stream). A regular user would not notice this for the wrapped index calls, but it would show up when calling some of the debugging functions (e.g., retrieving the GPU codes). The fix adds logic to the StandardGpuResources stream update functions to perform the required synchronization when the user manually changes the stream
- a function to retrieve encoded IVF data in either CPU-native or GPU interleaved format (see the sketch after this list)
- the CPU scalar quantizer object now directly reports how many bits are in a single scalar code; previously the only available information was how many bytes a full encoded vector used
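
A usage sketch of the new accessor (the helper function here is made up; getListVectorData and its gpuFormat flag are the ones added in this diff):

```
#include <faiss/gpu/GpuIndexIVFFlat.h>
#include <cstdint>
#include <vector>

// Sketch: retrieve the encoded codes of one inverted list in both formats.
void inspectList(const faiss::gpu::GpuIndexIVFFlat& index, int listId) {
  // CPU-native (IndexIVF-compatible) layout
  std::vector<uint8_t> cpuCodes = index.getListVectorData(listId);
  // Raw GPU layout: interleaved by 32, may be padded up to a multiple of 32 vectors
  std::vector<uint8_t> gpuCodes = index.getListVectorData(listId, true);
}
```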

Reviewed By: mdouze

Differential Revision: D24862911

fbshipit-source-id: 9a92486306b4b0c6ac30e5cd22c1ffbb6ed2faf4
Jeff Johnson 2020-12-15 21:16:33 -08:00 committed by Facebook GitHub Bot
parent 218a6a9b90
commit 90c891b616
53 changed files with 2912 additions and 519 deletions


@ -23,4 +23,4 @@ RUN cmake -B build \
-DCMAKE_CUDA_FLAGS="-gencode arch=compute_61,code=sm_61" \
.
RUN make -C build -j20
RUN make -C build -j8


@ -115,6 +115,7 @@ jobs:
docker run --gpus all faiss make -C build test
docker run --gpus all faiss sh -c '(cd build/faiss/python; python3 setup.py install) && python3 -m unittest discover -s faiss/gpu/test -p "test_*"'
docker run --gpus all faiss sh -c '(cd build/faiss/python; python3 setup.py install) && python3 -m unittest discover -s faiss/gpu/test -p "torch_*.py"'
no_output_timeout: 20m
deploy_linux:
parameters:


@ -23,10 +23,12 @@ target_sources(faiss PRIVATE
impl/BroadcastSum.cu
impl/Distance.cu
impl/FlatIndex.cu
impl/InterleavedCodes.cpp
impl/IVFAppend.cu
impl/IVFBase.cu
impl/IVFFlat.cu
impl/IVFFlatScan.cu
impl/IVFInterleaved.cu
impl/IVFPQ.cu
impl/IVFUtils.cu
impl/IVFUtilsSelect1.cu
@ -36,6 +38,14 @@ target_sources(faiss PRIVATE
impl/PQScanMultiPassPrecomputed.cu
impl/RemapIndices.cpp
impl/VectorResidual.cu
impl/scan/IVFInterleaved1.cu
impl/scan/IVFInterleaved32.cu
impl/scan/IVFInterleaved64.cu
impl/scan/IVFInterleaved128.cu
impl/scan/IVFInterleaved256.cu
impl/scan/IVFInterleaved512.cu
impl/scan/IVFInterleaved1024.cu
impl/scan/IVFInterleaved2048.cu
utils/BlockSelectFloat.cu
utils/BlockSelectHalf.cu
utils/DeviceUtils.cu
@ -176,4 +186,7 @@ endforeach()
find_package(CUDAToolkit REQUIRED)
find_package(OpenMP REQUIRED)
target_link_libraries(faiss PRIVATE OpenMP::OpenMP_CXX)
target_link_libraries(faiss PRIVATE CUDA::cudart CUDA::cublas)


@ -60,9 +60,13 @@ class GpuIndexIVF : public GpuIndex {
virtual int getListLength(int listId) const = 0;
/// Return the encoded vector data contained in a particular inverted list,
/// for debugging purposes. This is represented in a CPU Faiss (IndexIVF*)
/// for debugging purposes.
/// If gpuFormat is true, the data is returned as it is encoded in the
/// GPU-side representation.
/// Otherwise, it is converted to the CPU format.
/// compliant format, while the native GPU format may differ.
virtual std::vector<uint8_t> getListVectorData(int listId) const = 0;
virtual std::vector<uint8_t>
getListVectorData(int listId, bool gpuFormat = false) const = 0;
/// Return the vector indices contained in a particular inverted list, for
/// debugging purposes.


@ -87,6 +87,7 @@ GpuIndexIVFFlat::copyFrom(const faiss::IndexIVFFlat* index) {
index->metric_arg,
false, // no residual
nullptr, // no scalar quantizer
ivfFlatConfig_.interleavedLayout,
ivfFlatConfig_.indicesOptions,
config_.memorySpace));
@ -172,6 +173,7 @@ GpuIndexIVFFlat::train(Index::idx_t n, const float* x) {
this->metric_arg,
false, // no residual
nullptr, // no scalar quantizer
ivfFlatConfig_.interleavedLayout,
ivfFlatConfig_.indicesOptions,
config_.memorySpace));
@ -191,11 +193,11 @@ GpuIndexIVFFlat::getListLength(int listId) const {
}
std::vector<uint8_t>
GpuIndexIVFFlat::getListVectorData(int listId) const {
GpuIndexIVFFlat::getListVectorData(int listId, bool gpuFormat) const {
FAISS_ASSERT(index_);
DeviceScope scope(config_.device);
return index_->getListVectorData(listId);
return index_->getListVectorData(listId, gpuFormat);
}
std::vector<Index::idx_t>


@ -19,6 +19,13 @@ class IVFFlat;
class GpuIndexFlat;
struct GpuIndexIVFFlatConfig : public GpuIndexIVFConfig {
inline GpuIndexIVFFlatConfig()
: interleavedLayout(true) {
}
/// Use the alternative memory layout for the IVF lists
/// (currently the default)
bool interleavedLayout;
};
/// Wrapper around the GPU implementation that looks like
@ -66,9 +73,13 @@ class GpuIndexIVFFlat : public GpuIndexIVF {
int getListLength(int listId) const override;
/// Return the encoded vector data contained in a particular inverted list,
/// for debugging purposes. This is represented in a CPU Faiss (IndexIVF*)
/// for debugging purposes.
/// If gpuFormat is true, the data is returned as it is encoded in the
/// GPU-side representation.
/// Otherwise, it is converted to the CPU format.
/// compliant format, while the native GPU format may differ.
std::vector<uint8_t> getListVectorData(int listId) const override;
std::vector<uint8_t>
getListVectorData(int listId, bool gpuFormat = false) const override;
/// Return the vector indices contained in a particular inverted list, for
/// debugging purposes.


@ -104,7 +104,7 @@ GpuIndexIVFPQ::copyFrom(const faiss::IndexIVFPQ* index) {
bitsPerCode_,
ivfpqConfig_.useFloat16LookupTables,
ivfpqConfig_.useMMCodeDistance,
ivfpqConfig_.alternativeLayout,
ivfpqConfig_.interleavedLayout,
(float*) index->pq.centroids.data(),
ivfpqConfig_.indicesOptions,
config_.memorySpace));
@ -263,7 +263,7 @@ GpuIndexIVFPQ::trainResidualQuantizer_(Index::idx_t n, const float* x) {
bitsPerCode_,
ivfpqConfig_.useFloat16LookupTables,
ivfpqConfig_.useMMCodeDistance,
ivfpqConfig_.alternativeLayout,
ivfpqConfig_.interleavedLayout,
pq.centroids.data(),
ivfpqConfig_.indicesOptions,
config_.memorySpace));
@ -354,11 +354,11 @@ GpuIndexIVFPQ::getListLength(int listId) const {
}
std::vector<uint8_t>
GpuIndexIVFPQ::getListVectorData(int listId) const {
GpuIndexIVFPQ::getListVectorData(int listId, bool gpuFormat) const {
FAISS_ASSERT(index_);
DeviceScope scope(config_.device);
return index_->getListVectorData(listId);
return index_->getListVectorData(listId, gpuFormat);
}
std::vector<Index::idx_t>


@ -23,7 +23,7 @@ struct GpuIndexIVFPQConfig : public GpuIndexIVFConfig {
inline GpuIndexIVFPQConfig()
: useFloat16LookupTables(false),
usePrecomputedTables(false),
alternativeLayout(false),
interleavedLayout(false),
useMMCodeDistance(false) {
}
@ -38,7 +38,7 @@ struct GpuIndexIVFPQConfig : public GpuIndexIVFConfig {
/// Use the alternative memory layout for the IVF lists
/// WARNING: this is a feature under development, do not use!
bool alternativeLayout;
bool interleavedLayout;
/// Use GEMM-backed computation of PQ code distances for the no precomputed
/// table version of IVFPQ.
@ -115,9 +115,13 @@ class GpuIndexIVFPQ : public GpuIndexIVF {
int getListLength(int listId) const override;
/// Return the encoded vector data contained in a particular inverted list,
/// for debugging purposes. This is represented in a CPU Faiss (IndexIVF*)
/// for debugging purposes.
/// If gpuFormat is true, the data is returned as it is encoded in the
/// GPU-side representation.
/// Otherwise, it is converted to the CPU format.
/// compliant format, while the native GPU format may differ.
std::vector<uint8_t> getListVectorData(int listId) const override;
std::vector<uint8_t>
getListVectorData(int listId, bool gpuFormat = false) const override;
/// Return the vector indices contained in a particular inverted list, for
/// debugging purposes.


@ -100,6 +100,7 @@ GpuIndexIVFScalarQuantizer::copyFrom(
index->metric_arg,
by_residual,
&sq,
ivfSQConfig_.interleavedLayout,
ivfSQConfig_.indicesOptions,
config_.memorySpace));
@ -164,11 +165,12 @@ GpuIndexIVFScalarQuantizer::getListLength(int listId) const {
}
std::vector<uint8_t>
GpuIndexIVFScalarQuantizer::getListVectorData(int listId) const {
GpuIndexIVFScalarQuantizer::getListVectorData(int listId,
bool gpuFormat) const {
FAISS_ASSERT(index_);
DeviceScope scope(config_.device);
return index_->getListVectorData(listId);
return index_->getListVectorData(listId, gpuFormat);
}
std::vector<Index::idx_t>
@ -220,6 +222,7 @@ GpuIndexIVFScalarQuantizer::train(Index::idx_t n, const float* x) {
this->metric_arg,
by_residual,
&sq,
ivfSQConfig_.interleavedLayout,
ivfSQConfig_.indicesOptions,
config_.memorySpace));


@ -18,6 +18,13 @@ class IVFFlat;
class GpuIndexFlat;
struct GpuIndexIVFScalarQuantizerConfig : public GpuIndexIVFConfig {
inline GpuIndexIVFScalarQuantizerConfig()
: interleavedLayout(true) {
}
/// Use the alternative memory layout for the IVF lists
/// (currently the default)
bool interleavedLayout;
};
/// Wrapper around the GPU implementation that looks like
@ -72,9 +79,13 @@ class GpuIndexIVFScalarQuantizer : public GpuIndexIVF {
int getListLength(int listId) const override;
/// Return the encoded vector data contained in a particular inverted list,
/// for debugging purposes. This is represented in a CPU Faiss (IndexIVF*)
/// for debugging purposes.
/// If gpuFormat is true, the data is returned as it is encoded in the
/// GPU-side representation.
/// Otherwise, it is converted to the CPU format.
/// compliant format, while the native GPU format may differ.
std::vector<uint8_t> getListVectorData(int listId) const override;
std::vector<uint8_t>
getListVectorData(int listId, bool gpuFormat = false) const override;
/// Return the vector indices contained in a particular inverted list, for
/// debugging purposes.


@ -206,11 +206,44 @@ StandardGpuResourcesImpl::setPinnedMemory(size_t size) {
void
StandardGpuResourcesImpl::setDefaultStream(int device, cudaStream_t stream) {
if (isInitialized(device)) {
// A new series of calls may not be ordered with what was the previous
// stream, so if the stream being specified is different, then we need to
// ensure ordering between the two (new stream waits on old).
auto it = userDefaultStreams_.find(device);
cudaStream_t prevStream = nullptr;
if (it != userDefaultStreams_.end()) {
prevStream = it->second;
} else {
FAISS_ASSERT(defaultStreams_.count(device));
prevStream = defaultStreams_[device];
}
if (prevStream != stream) {
streamWait({stream}, {prevStream});
}
}
userDefaultStreams_[device] = stream;
}
void
StandardGpuResourcesImpl::revertDefaultStream(int device) {
if (isInitialized(device)) {
auto it = userDefaultStreams_.find(device);
if (it != userDefaultStreams_.end()) {
// There was a user stream set that we need to synchronize against
cudaStream_t prevStream = userDefaultStreams_[device];
FAISS_ASSERT(defaultStreams_.count(device));
cudaStream_t newStream = defaultStreams_[device];
streamWait({newStream}, {prevStream});
}
}
userDefaultStreams_.erase(device);
}


@ -11,6 +11,8 @@
#include <faiss/gpu/utils/ConversionOperators.cuh>
#include <faiss/gpu/utils/DeviceTensor.cuh>
#include <faiss/gpu/utils/HostTensor.cuh>
#include <faiss/gpu/utils/PtxUtils.cuh>
#include <faiss/gpu/utils/WarpShuffles.cuh>
namespace faiss { namespace gpu {
@ -21,6 +23,7 @@ inline bool isSQSupported(ScalarQuantizer::QuantizerType qtype) {
case ScalarQuantizer::QuantizerType::QT_8bit_direct:
case ScalarQuantizer::QuantizerType::QT_4bit:
case ScalarQuantizer::QuantizerType::QT_4bit_uniform:
case ScalarQuantizer::QuantizerType::QT_6bit:
case ScalarQuantizer::QuantizerType::QT_fp16:
return true;
default:
@ -41,9 +44,8 @@ struct GpuScalarQuantizer : public ScalarQuantizer {
HostTensor<float, 1, true>
cpuTrained((float*) sq.trained.data(), {(int) sq.trained.size()});
// Just use the default stream, as we're allocating memory above in any case
gpuTrained.copyFrom(cpuTrained, 0);
CUDA_VERIFY(cudaStreamSynchronize(0));
auto stream = res->getDefaultStreamCurrentDevice();
gpuTrained.copyFrom(cpuTrained, stream);
}
// ScalarQuantizer::trained copied to GPU memory
@ -74,7 +76,7 @@ struct CodecFloat {
CodecFloat(int vecBytes) : bytesPerVec(vecBytes) { }
size_t getSmemSize(int dim) { return 0; }
inline __device__ void setSmem(float* smem, int dim) { }
inline __device__ void initKernel(float* smem, int dim) { }
inline __device__ void decode(void* data, int vec, int d,
float* out) const {
@ -100,6 +102,20 @@ struct CodecFloat {
// doesn't need implementing (kDimPerIter == 1)
}
//
// new implementation
//
using EncodeT = float;
static constexpr int kEncodeBits = 32;
inline __device__ EncodeT encodeNew(int dim, float v) const {
return v;
}
inline __device__ float decodeNew(int dim, EncodeT v) const {
return v;
}
int bytesPerVec;
};
@ -118,7 +134,7 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_fp16, 1> {
Codec(int vecBytes) : bytesPerVec(vecBytes) { }
size_t getSmemSize(int dim) { return 0; }
inline __device__ void setSmem(float* smem, int dim) { }
inline __device__ void initKernel(float* smem, int dim) { }
inline __device__ void decode(void* data, int vec, int d,
float* out) const {
@ -144,50 +160,18 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_fp16, 1> {
// doesn't need implementing (kDimPerIter == 1)
}
int bytesPerVec;
};
//
// new implementation
//
using EncodeT = half;
static constexpr int kEncodeBits = 16;
// dim % 2 == 0, ensures uint32 alignment
template <>
struct Codec<ScalarQuantizer::QuantizerType::QT_fp16, 2> {
/// How many dimensions per iteration we are handling for encoding or decoding
static constexpr int kDimPerIter = 2;
Codec(int vecBytes) : bytesPerVec(vecBytes) { }
size_t getSmemSize(int dim) { return 0; }
inline __device__ void setSmem(float* smem, int dim) { }
inline __device__ void decode(void* data, int vec, int d,
float* out) const {
half2* p = (half2*) &((uint8_t*) data)[vec * bytesPerVec];
half2 pd = p[d];
out[0] = Convert<half, float>()(__low2half(pd));
out[1] = Convert<half, float>()(__high2half(pd));
inline __device__ EncodeT encodeNew(int dim, float v) const {
return Convert<float, half>()(v);
}
inline __device__ float decodePartial(void* data, int vec, int d,
int subD) const {
// should not be called
assert(false);
return 0;
}
inline __device__ void encode(void* data, int vec, int d,
float v[kDimPerIter]) const {
half2* p = (half2*) &((uint8_t*) data)[vec * bytesPerVec];
half h0 = Convert<float, half>()(v[0]);
half h1 = Convert<float, half>()(v[1]);
p[d] = __halves2half2(h0, h1);
}
inline __device__ void encodePartial(void* data, int vec, int d,
int remaining,
float v[kDimPerIter]) const {
// should not be called
assert(false);
inline __device__ float decodeNew(int dim, EncodeT v) const {
return Convert<half, float>()(v);
}
int bytesPerVec;
@ -223,11 +207,18 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_8bit_uniform, DimMultiple> {
}
size_t getSmemSize(int dim) { return 0; }
inline __device__ void setSmem(float* smem, int dim) { }
inline __device__ void initKernel(float* smem, int dim) {
// We are performing vmin + vdiff * (v + 0.5) / (2^bits - 1)
// This can be simplified to vmin' + vdiff' * v where:
// vdiff' = vdiff / (2^bits - 1)
// vmin' = vmin + 0.5 * vdiff'
auto vd = vdiff * (1.0f / 255.0f);
vmin = vmin + 0.5f * vd;
vdiff = vd;
}
inline __device__ float decodeHelper(uint8_t v) const {
float x = (((float) v) + 0.5f) / 255.0f;
return vmin + x * vdiff;
return vmin + (float) v * vdiff;
}
inline __device__ void decode(void* data, int vec, int d,
@ -267,7 +258,7 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_8bit_uniform, DimMultiple> {
inline __device__ uint8_t encodeHelper(float v) const {
float x = (v - vmin) / vdiff;
x = fminf(1.0f, fmaxf(0.0f, x));
return (uint8_t) (255 * x);
return (uint8_t) (x * 255.0f);
}
inline __device__ void encode(void* data, int vec, int d,
@ -300,9 +291,23 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_8bit_uniform, DimMultiple> {
// otherwise does not need implementing
}
//
// interleaved code implementation
//
using EncodeT = uint8_t;
static constexpr int kEncodeBits = 8;
inline __device__ EncodeT encodeNew(int dim, float v) const {
return encodeHelper(v);
}
inline __device__ float decodeNew(int dim, EncodeT v) const {
return decodeHelper(v);
}
int bytesPerVec;
const float vmin;
const float vdiff;
float vmin;
float vdiff;
};
// Uniform quantization per each dimension
@ -322,19 +327,26 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_8bit, DimMultiple> {
return sizeof(float) * dim * 2;
}
inline __device__ void setSmem(float* smem, int dim) {
// Initialize shared memory and local storage
// It is up to the user to call a trailing syncthreads (after any other
// initialization required has been done)
inline __device__ void initKernel(float* smem, int dim) {
smemVmin = smem;
smemVdiff = smem + dim;
for (int i = threadIdx.x; i < dim; i += blockDim.x) {
smemVmin[i] = vmin[i];
smemVdiff[i] = vdiff[i];
// We are performing vmin + vdiff * (v + 0.5) / (2^bits - 1)
// This can be simplified to vmin' + vdiff' * v where:
// vdiff' = vdiff / (2^bits - 1)
// vmin' = vmin + 0.5 * vdiff'
auto vd = vdiff[i] * (1.0f / 255.0f);
smemVmin[i] = vmin[i] + 0.5f * vd;
smemVdiff[i] = vd;
}
}
inline __device__ float decodeHelper(uint8_t v, int realDim) const {
float x = (((float) v) + 0.5f) / 255.0f;
return smemVmin[realDim] + x * smemVdiff[realDim];
return smemVmin[realDim] + (float) v * smemVdiff[realDim];
}
inline __device__ void decode(void* data, int vec, int d,
@ -375,7 +387,7 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_8bit, DimMultiple> {
inline __device__ uint8_t encodeHelper(float v, int realDim) const {
float x = (v - vmin[realDim]) / vdiff[realDim];
x = fminf(1.0f, fmaxf(0.0f, x));
return (uint8_t) (255 * x);
return (uint8_t) (x * 255.0f);
}
inline __device__ void encode(void* data, int vec, int d,
@ -409,6 +421,20 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_8bit, DimMultiple> {
// otherwise does not need implementing
}
//
// interleaved code implementation
//
using EncodeT = uint8_t;
static constexpr int kEncodeBits = 8;
inline __device__ EncodeT encodeNew(int dim, float v) const {
return encodeHelper(v, dim);
}
inline __device__ float decodeNew(int dim, EncodeT v) const {
return decodeHelper(v, dim);
}
int bytesPerVec;
// gmem pointers
@ -428,7 +454,7 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_8bit_direct, 1> {
Codec(int vecBytes) : bytesPerVec(vecBytes) { }
size_t getSmemSize(int dim) { return 0; }
inline __device__ void setSmem(float* smem, int dim) { }
inline __device__ void initKernel(float* smem, int dim) { }
inline __device__ void decode(void* data, int vec, int d,
float* out) const {
@ -454,9 +480,94 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_8bit_direct, 1> {
// doesn't need implementing (kDimPerIter == 1)
}
//
// interleaved code implementation
//
using EncodeT = uint8_t;
static constexpr int kEncodeBits = 8;
inline __device__ EncodeT encodeNew(int dim, float v) const {
return (uint8_t) v;
}
inline __device__ float decodeNew(int dim, EncodeT v) const {
return (float) v;
}
int bytesPerVec;
};
/////
//
// 6 bit encodings
//
/////
template <>
struct Codec<ScalarQuantizer::QuantizerType::QT_6bit, 1> {
Codec(int vecBytes, float* min, float* diff)
: bytesPerVec(vecBytes), vmin(min), vdiff(diff),
smemVmin(nullptr),
smemVdiff(nullptr) {
}
size_t getSmemSize(int dim) {
return sizeof(float) * dim * 2;
}
// Initialize shared memory and local storage
// It is up to the user to call a trailing syncthreads (after any other
// initialization required has been done)
inline __device__ void initKernel(float* smem, int dim) {
smemVmin = smem;
smemVdiff = smem + dim;
for (int i = threadIdx.x; i < dim; i += blockDim.x) {
// We are performing vmin + vdiff * (v + 0.5) / (2^bits - 1)
// This can be simplified to vmin' + vdiff' * v where:
// vdiff' = vdiff / (2^bits - 1)
// vmin' = vmin + 0.5 * vdiff'
auto vd = vdiff[i] * (1.0f / 63.0f);
smemVmin[i] = vmin[i] + 0.5f * vd;
smemVdiff[i] = vd;
}
}
inline __device__ float decodeHelper(uint8_t v, int realDim) const {
return smemVmin[realDim] + (float) v * smemVdiff[realDim];
}
inline __device__ uint8_t encodeHelper(float v, int realDim) const {
float x = (v - vmin[realDim]) / vdiff[realDim];
x = fminf(1.0f, fmaxf(0.0f, x));
return (uint8_t) (x * 63.0f);
}
//
// interleaved code implementation
//
using EncodeT = uint8_t;
static constexpr int kEncodeBits = 6;
inline __device__ EncodeT encodeNew(int dim, float v) const {
return encodeHelper(v, dim);
}
inline __device__ float decodeNew(int dim, EncodeT v) const {
return decodeHelper(v, dim);
}
int bytesPerVec;
// gmem pointers
const float* vmin;
const float* vdiff;
// smem pointers
float* smemVmin;
float* smemVdiff;
};
/////
//
// 4 bit encodings
@ -474,11 +585,18 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_4bit_uniform, 1> {
}
size_t getSmemSize(int dim) { return 0; }
inline __device__ void setSmem(float* smem, int dim) { }
inline __device__ void initKernel(float* smem, int dim) {
// We are performing vmin + vdiff * (v + 0.5) / (2^bits - 1)
// This can be simplified to vmin' + vdiff' * v where:
// vdiff' = vdiff / (2^bits - 1)
// vmin' = vmin + 0.5 * vdiff'
auto vd = vdiff * (1.0f / 15.0f);
vmin = vmin + 0.5f * vd;
vdiff = vd;
}
inline __device__ float decodeHelper(uint8_t v) const {
float x = (((float) v) + 0.5f) / 15.0f;
return vmin + x * vdiff;
return vmin + (float) v * vdiff;
}
inline __device__ void decode(void* data, int vec, int d,
@ -519,9 +637,23 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_4bit_uniform, 1> {
p[d] = encodeHelper(v[0]);
}
//
// interleaved code implementation
//
using EncodeT = uint8_t;
static constexpr int kEncodeBits = 4;
inline __device__ EncodeT encodeNew(int dim, float v) const {
return encodeHelper(v);
}
inline __device__ float decodeNew(int dim, EncodeT v) const {
return decodeHelper(v);
}
int bytesPerVec;
const float vmin;
const float vdiff;
float vmin;
float vdiff;
};
template <>
@ -539,19 +671,25 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_4bit, 1> {
return sizeof(float) * dim * 2;
}
inline __device__ void setSmem(float* smem, int dim) {
inline __device__ void initKernel(float* smem, int dim) {
smemVmin = smem;
smemVdiff = smem + dim;
for (int i = threadIdx.x; i < dim; i += blockDim.x) {
smemVmin[i] = vmin[i];
smemVdiff[i] = vdiff[i];
// We are performing vmin + vdiff * (v + 0.5) / (2^bits - 1)
// This can be simplified to vmin' + vdiff' * v where:
// vdiff' = vdiff / (2^bits - 1)
// vmin' = vmin + 0.5 * vdiff'
auto vd = vdiff[i] / 15.0f;
smemVmin[i] = vmin[i] + 0.5f * vd;
smemVdiff[i] = vd;
}
__syncthreads();
}
inline __device__ float decodeHelper(uint8_t v, int realDim) const {
float x = (((float) v) + 0.5f) / 15.0f;
return smemVmin[realDim] + x * smemVdiff[realDim];
return smemVmin[realDim] + (float) v * smemVdiff[realDim];
}
inline __device__ void decode(void* data, int vec, int d,
@ -597,6 +735,20 @@ struct Codec<ScalarQuantizer::QuantizerType::QT_4bit, 1> {
p[d] = encodeHelper(v[0], realDim);
}
//
// interleaved code implementation
//
using EncodeT = uint8_t;
static constexpr int kEncodeBits = 4;
inline __device__ EncodeT encodeNew(int dim, float v) const {
return encodeHelper(v, dim);
}
inline __device__ float decodeNew(int dim, EncodeT v) const {
return decodeHelper(v, dim);
}
int bytesPerVec;
// gmem pointers


@ -9,9 +9,13 @@
#include <faiss/gpu/impl/IVFAppend.cuh>
#include <faiss/impl/FaissAssert.h>
#include <faiss/gpu/utils/Float16.cuh>
#include <faiss/gpu/utils/DeviceDefs.cuh>
#include <faiss/gpu/utils/DeviceUtils.h>
#include <faiss/gpu/utils/PtxUtils.cuh>
#include <faiss/gpu/utils/Tensor.cuh>
#include <faiss/gpu/utils/StaticUtils.h>
#include <faiss/gpu/utils/WarpPackedBits.cuh>
#include <faiss/gpu/utils/WarpShuffles.cuh>
namespace faiss { namespace gpu {
@ -19,6 +23,7 @@ namespace faiss { namespace gpu {
// IVF list length update
//
// Updates the device-size array of list start pointers for codes and indices
__global__ void
runUpdateListPointers(Tensor<int, 1, true> listIds,
Tensor<int, 1, true> newListLength,
@ -61,6 +66,65 @@ runUpdateListPointers(Tensor<int, 1, true>& listIds,
CUDA_TEST_ERROR();
}
// Appends new indices for vectors being added to the IVF indices lists
__global__ void
ivfIndicesAppend(Tensor<int, 1, true> listIds,
Tensor<int, 1, true> listOffset,
Tensor<Index::idx_t, 1, true> indices,
IndicesOptions opt,
void** listIndices) {
int vec = blockIdx.x * blockDim.x + threadIdx.x;
if (vec >= listIds.getSize(0)) {
return;
}
int listId = listIds[vec];
int offset = listOffset[vec];
// Add vector could be invalid (contains NaNs etc)
if (listId == -1 || offset == -1) {
return;
}
auto index = indices[vec];
if (opt == INDICES_32_BIT) {
// FIXME: there could be overflow here, but where should we check this?
((int*) listIndices[listId])[offset] = (int) index;
} else if (opt == INDICES_64_BIT) {
((Index::idx_t*) listIndices[listId])[offset] = index;
}
}
void
runIVFIndicesAppend(Tensor<int, 1, true>& listIds,
Tensor<int, 1, true>& listOffset,
Tensor<Index::idx_t, 1, true>& indices,
IndicesOptions opt,
thrust::device_vector<void*>& listIndices,
cudaStream_t stream) {
FAISS_ASSERT(opt == INDICES_CPU ||
opt == INDICES_IVF ||
opt == INDICES_32_BIT ||
opt == INDICES_64_BIT);
if (opt != INDICES_CPU && opt != INDICES_IVF) {
int num = listIds.getSize(0);
int threads = std::min(num, getMaxThreadsCurrentDevice());
int blocks = utils::divUp(num, threads);
ivfIndicesAppend<<<blocks, threads, 0, stream>>>(
listIds,
listOffset,
indices,
opt,
listIndices.data().get());
CUDA_TEST_ERROR();
}
}
//
// IVF PQ append
//
@ -69,11 +133,8 @@ __global__ void
ivfpqInvertedListAppend(Tensor<int, 1, true> listIds,
Tensor<int, 1, true> listOffset,
Tensor<int, 2, true> encodings,
Tensor<Index::idx_t, 1, true> indices,
IndicesOptions opt,
bool layoutBy32,
void** listCodes,
void** listIndices) {
void** listCodes) {
int encodingToAdd = blockIdx.x * blockDim.x + threadIdx.x;
if (encodingToAdd >= listIds.getSize(0)) {
@ -89,16 +150,6 @@ ivfpqInvertedListAppend(Tensor<int, 1, true> listIds,
}
auto encoding = encodings[encodingToAdd];
auto index = indices[encodingToAdd];
if (opt == INDICES_32_BIT) {
// FIXME: there could be overflow here, but where should we check this?
((int*) listIndices[listId])[vectorNumInList] = (int) index;
} else if (opt == INDICES_64_BIT) {
((Index::idx_t*) listIndices[listId])[vectorNumInList] = index;
} else {
// INDICES_CPU or INDICES_IVF; no indices are being stored
}
if (layoutBy32) {
// Layout with 32 vectors interleaved
@ -109,6 +160,8 @@ ivfpqInvertedListAppend(Tensor<int, 1, true> listIds,
uint8_t* codeStart = ((uint8_t*) listCodes[listId]) + start;
// This is actually properly warp coalesced, as each thread handles a
// different vector
for (int i = 0; i < encodings.getSize(1); ++i) {
codeStart[i * 32] = (uint8_t) encoding[i];
}
@ -135,22 +188,16 @@ runIVFPQInvertedListAppend(Tensor<int, 1, true>& listIds,
thrust::device_vector<void*>& listIndices,
IndicesOptions indicesOptions,
cudaStream_t stream) {
int numThreads = std::min(listIds.getSize(0), getMaxThreadsCurrentDevice());
int numBlocks = utils::divUp(listIds.getSize(0), numThreads);
// Append the indices that we're about to add, if any
runIVFIndicesAppend(listIds, listOffset, indices,
indicesOptions, listIndices, stream);
dim3 grid(numBlocks);
dim3 block(numThreads);
// Append the PQ codes
int threads = std::min(listIds.getSize(0), getMaxThreadsCurrentDevice());
int blocks = utils::divUp(listIds.getSize(0), threads);
FAISS_ASSERT(indicesOptions == INDICES_CPU ||
indicesOptions == INDICES_IVF ||
indicesOptions == INDICES_32_BIT ||
indicesOptions == INDICES_64_BIT);
ivfpqInvertedListAppend<<<grid, block, 0, stream>>>(
listIds, listOffset, encodings, indices,
indicesOptions, layoutBy32,
listCodes.data().get(),
listIndices.data().get());
ivfpqInvertedListAppend<<<threads, blocks, 0, stream>>>(
listIds, listOffset, encodings, layoutBy32, listCodes.data().get());
CUDA_TEST_ERROR();
}
@ -159,43 +206,126 @@ runIVFPQInvertedListAppend(Tensor<int, 1, true>& listIds,
// IVF flat append
//
template <typename Codec>
__global__ void
ivfFlatIndicesAppend(Tensor<int, 1, true> listIds,
Tensor<int, 1, true> listOffset,
Tensor<Index::idx_t, 1, true> indices,
IndicesOptions opt,
void** listIndices) {
int vec = blockIdx.x * blockDim.x + threadIdx.x;
ivfFlatInterleavedAppend(// the IDs (offset in listData) of the unique lists
// being added to
Tensor<int, 1, true> uniqueLists,
// For each of the list IDs in uniqueLists, the start
// offset in vectorsByUniqueList for the vectors that
// we are adding to that list
Tensor<int, 1, true> uniqueListVectorStart,
// IDs in vecs of the vectors being added to each
// unique list
// The vectors (offset in vecs) added to
// uniqueLists[i] is:
// {vBUL[uLVS[i]], ..., vBUL[uLVS[i+1] - 1]}
Tensor<int, 1, true> vectorsByUniqueList,
// For each of the list IDs in uniqueLists, the start
// offset (by vector) within that list where we begin
// appending
Tensor<int, 1, true> uniqueListStartOffset,
// The vectors being added
Tensor<float, 2, true> vecs,
// The set of addresses for each of the lists
void** listData,
Codec codec) {
using EncodeT = typename Codec::EncodeT;
if (vec >= listIds.getSize(0)) {
return;
}
// FIXME: some issue with getLaneId() and CUDA 10.1 and P4 GPUs?
int laneId = threadIdx.x % kWarpSize;
int warpId = threadIdx.x / kWarpSize;
int warpsPerBlock = blockDim.x / kWarpSize;
int listId = listIds[vec];
int offset = listOffset[vec];
// Each block is dedicated to a separate list
int listId = uniqueLists[blockIdx.x];
// Add vector could be invalid (contains NaNs etc)
if (listId == -1 || offset == -1) {
return;
}
// The vecs we add to the list are at indices [vBUL[vecIdStart], vBUL[vecIdEnd])
int vecIdStart = uniqueListVectorStart[blockIdx.x];
// uLVS is explicitly terminated for us with one more than the number of
// blocks that we have
int vecIdEnd = uniqueListVectorStart[blockIdx.x + 1];
auto index = indices[vec];
// How many vectors we are adding to this list
int numVecsAdding = vecIdEnd - vecIdStart;
if (opt == INDICES_32_BIT) {
// FIXME: there could be overflow here, but where should we check this?
((int*) listIndices[listId])[offset] = (int) index;
} else if (opt == INDICES_64_BIT) {
((Index::idx_t*) listIndices[listId])[offset] = index;
// The first vector we are updating within the list
auto listVecStart = uniqueListStartOffset[blockIdx.x];
// These are the actual vec IDs that we are adding (in vecs)
int* listVecIds = vectorsByUniqueList[vecIdStart].data();
// All data is written by groups of 32 vectors (to mirror the warp).
// listVecStart could be in the middle of this, or even, for sub-byte
// encodings, mean that the first vector piece of data that we need to update
// is in the high part of a byte.
//
// WarpPackedBits allows writing of arbitrary bit packed data in groups of 32,
// but we ensure that it only operates on the group of 32 vectors.
// In order to do this we need to actually start updating vectors at the next
// lower multiple of 32 from listVecStart.
int alignedListVecStart = utils::roundDown(listVecStart, 32);
// Each block of 32 vectors fully encodes into this many bytes
constexpr int bytesPerVectorBlockDim = Codec::kEncodeBits * 32 / 8;
constexpr int wordsPerVectorBlockDim = bytesPerVectorBlockDim / sizeof(EncodeT);
int wordsPerVectorBlock = wordsPerVectorBlockDim * vecs.getSize(1);
EncodeT* listStart = ((EncodeT*) listData[listId]);
// Each warp within the block handles a different chunk of 32
int warpVec = alignedListVecStart + warpId * 32;
// The warp data starts here
EncodeT* warpData = listStart + (warpVec / 32) * wordsPerVectorBlock;
// Each warp encodes a single block
for (; warpVec < listVecStart + numVecsAdding;
// but block stride
warpVec += blockDim.x,
// the new warp data base strides by how many vector blocks we are
// encoding, which is one per warp
warpData += warpsPerBlock * wordsPerVectorBlock) {
// This lane is adding this vec (if it is within bounds)
int laneVec = warpVec + laneId;
// Which vector does this correspond to in the set of vectors that we need
// to add?
// If this is < 0, then this particular thread is not encoding / appending a
// new vector
int laneVecAdding = laneVec - listVecStart;
// We are actually adding a new vector if this is within range
bool valid = laneVecAdding >= 0 && laneVecAdding < numVecsAdding;
// Now, which actual vector in vecs is this?
int vecId = valid ? listVecIds[laneVecAdding] : 0;
// Each warp that has some vector data available needs to write out the
// vector components
EncodeT* data = warpData;
for (int dim = 0;
dim < vecs.getSize(1); ++dim,
data += wordsPerVectorBlockDim) {
EncodeT enc = 0;
if (valid) {
enc = codec.encodeNew(dim, vecs[vecId][dim]);
}
WarpPackedBits<EncodeT, Codec::kEncodeBits>::
write(laneId, enc, valid, data);
}
}
}
template <typename Codec>
__global__ void
ivfFlatInvertedListAppend(Tensor<int, 1, true> listIds,
Tensor<int, 1, true> listOffset,
Tensor<float, 2, true> vecs,
void** listData,
Codec codec) {
ivfFlatAppend(Tensor<int, 1, true> listIds,
Tensor<int, 1, true> listOffset,
Tensor<float, 2, true> vecs,
void** listData,
Codec codec) {
int vec = blockIdx.x;
int listId = listIds[vec];
@ -247,43 +377,36 @@ ivfFlatInvertedListAppend(Tensor<int, 1, true> listIds,
}
void
runIVFFlatInvertedListAppend(Tensor<int, 1, true>& listIds,
Tensor<int, 1, true>& listOffset,
Tensor<float, 2, true>& vecs,
Tensor<Index::idx_t, 1, true>& indices,
bool useResidual,
Tensor<float, 2, true>& residuals,
GpuScalarQuantizer* scalarQ,
thrust::device_vector<void*>& listData,
thrust::device_vector<void*>& listIndices,
IndicesOptions indicesOptions,
cudaStream_t stream) {
runIVFFlatInterleavedAppend(Tensor<int, 1, true>& listIds,
Tensor<int, 1, true>& listOffset,
Tensor<int, 1, true>& uniqueLists,
Tensor<int, 1, true>& vectorsByUniqueList,
Tensor<int, 1, true>& uniqueListVectorStart,
Tensor<int, 1, true>& uniqueListStartOffset,
Tensor<float, 2, true>& vecs,
Tensor<Index::idx_t, 1, true>& indices,
GpuScalarQuantizer* scalarQ,
thrust::device_vector<void*>& listData,
thrust::device_vector<void*>& listIndices,
IndicesOptions indicesOptions,
cudaStream_t stream) {
int dim = vecs.getSize(1);
int maxThreads = getMaxThreadsCurrentDevice();
// First, append the indices that we're about to add, if any
if (indicesOptions != INDICES_CPU && indicesOptions != INDICES_IVF) {
int blocks = utils::divUp(vecs.getSize(0), maxThreads);
runIVFIndicesAppend(listIds, listOffset, indices,
indicesOptions, listIndices, stream);
ivfFlatIndicesAppend<<<blocks, maxThreads, 0, stream>>>(
listIds,
listOffset,
indices,
indicesOptions,
listIndices.data().get());
}
// Each block will handle appending a single vector
#define RUN_APPEND \
do { \
dim3 grid(vecs.getSize(0)); \
dim3 block(std::min(dim / codec.kDimPerIter, maxThreads)); \
\
ivfFlatInvertedListAppend \
dim3 grid(uniqueLists.getSize(0)); \
dim3 block(128); \
ivfFlatInterleavedAppend \
<<<grid, block, 0, stream>>>( \
listIds, \
listOffset, \
useResidual ? residuals : vecs, \
uniqueLists, \
uniqueListVectorStart, \
vectorsByUniqueList, \
uniqueListStartOffset, \
vecs, \
listData.data().get(), \
codec); \
} while (0)
@ -295,48 +418,127 @@ runIVFFlatInvertedListAppend(Tensor<int, 1, true>& listIds,
switch (scalarQ->qtype) {
case ScalarQuantizer::QuantizerType::QT_8bit:
{
if (false) {
// if (dim % 4 == 0) {
Codec<ScalarQuantizer::QuantizerType::QT_8bit, 4>
codec(scalarQ->code_size,
scalarQ->gpuTrained.data(),
scalarQ->gpuTrained.data() + dim);
RUN_APPEND;
} else {
Codec<ScalarQuantizer::QuantizerType::QT_8bit, 1>
codec(scalarQ->code_size,
scalarQ->gpuTrained.data(),
scalarQ->gpuTrained.data() + dim);
RUN_APPEND;
}
Codec<ScalarQuantizer::QuantizerType::QT_8bit, 1>
codec(scalarQ->code_size,
scalarQ->gpuTrained.data(),
scalarQ->gpuTrained.data() + dim);
RUN_APPEND;
}
break;
case ScalarQuantizer::QuantizerType::QT_8bit_uniform:
{
// if (dim % 4 == 0) {
if (false) {
Codec<ScalarQuantizer::QuantizerType::QT_8bit_uniform, 4>
codec(scalarQ->code_size, scalarQ->trained[0], scalarQ->trained[1]);
RUN_APPEND;
} else {
Codec<ScalarQuantizer::QuantizerType::QT_8bit_uniform, 1>
codec(scalarQ->code_size, scalarQ->trained[0], scalarQ->trained[1]);
RUN_APPEND;
}
Codec<ScalarQuantizer::QuantizerType::QT_8bit_uniform, 1>
codec(scalarQ->code_size, scalarQ->trained[0], scalarQ->trained[1]);
RUN_APPEND;
}
break;
case ScalarQuantizer::QuantizerType::QT_fp16:
{
// if (dim % 2 == 0) {
if (false) {
Codec<ScalarQuantizer::QuantizerType::QT_fp16, 2>
codec(scalarQ->code_size);
RUN_APPEND;
} else {
Codec<ScalarQuantizer::QuantizerType::QT_fp16, 1>
codec(scalarQ->code_size);
RUN_APPEND;
}
Codec<ScalarQuantizer::QuantizerType::QT_fp16, 1>
codec(scalarQ->code_size);
RUN_APPEND;
}
break;
case ScalarQuantizer::QuantizerType::QT_8bit_direct:
{
Codec<ScalarQuantizer::QuantizerType::QT_8bit_direct, 1>
codec(scalarQ->code_size);
RUN_APPEND;
}
break;
case ScalarQuantizer::QuantizerType::QT_6bit:
{
Codec<ScalarQuantizer::QuantizerType::QT_6bit, 1>
codec(scalarQ->code_size,
scalarQ->gpuTrained.data(),
scalarQ->gpuTrained.data() + dim);
RUN_APPEND;
}
break;
case ScalarQuantizer::QuantizerType::QT_4bit:
{
Codec<ScalarQuantizer::QuantizerType::QT_4bit, 1>
codec(scalarQ->code_size,
scalarQ->gpuTrained.data(),
scalarQ->gpuTrained.data() + dim);
RUN_APPEND;
}
break;
case ScalarQuantizer::QuantizerType::QT_4bit_uniform:
{
Codec<ScalarQuantizer::QuantizerType::QT_4bit_uniform, 1>
codec(scalarQ->code_size, scalarQ->trained[0], scalarQ->trained[1]);
RUN_APPEND;
}
break;
default:
// unimplemented, should be handled at a higher level
FAISS_ASSERT(false);
}
}
CUDA_TEST_ERROR();
#undef RUN_APPEND
}
void
runIVFFlatAppend(Tensor<int, 1, true>& listIds,
Tensor<int, 1, true>& listOffset,
Tensor<float, 2, true>& vecs,
Tensor<Index::idx_t, 1, true>& indices,
GpuScalarQuantizer* scalarQ,
thrust::device_vector<void*>& listData,
thrust::device_vector<void*>& listIndices,
IndicesOptions indicesOptions,
cudaStream_t stream) {
int dim = vecs.getSize(1);
int maxThreads = getMaxThreadsCurrentDevice();
// First, append the indices that we're about to add, if any
runIVFIndicesAppend(listIds, listOffset, indices,
indicesOptions, listIndices, stream);
// Each block will handle appending a single vector
#define RUN_APPEND \
do { \
dim3 grid(vecs.getSize(0)); \
dim3 block(std::min(dim / codec.kDimPerIter, maxThreads)); \
ivfFlatAppend \
<<<grid, block, 0, stream>>>( \
listIds, \
listOffset, \
vecs, \
listData.data().get(), \
codec); \
} while (0)
if (!scalarQ) {
CodecFloat codec(dim * sizeof(float));
RUN_APPEND;
} else {
switch (scalarQ->qtype) {
case ScalarQuantizer::QuantizerType::QT_8bit:
{
Codec<ScalarQuantizer::QuantizerType::QT_8bit, 1>
codec(scalarQ->code_size,
scalarQ->gpuTrained.data(),
scalarQ->gpuTrained.data() + dim);
RUN_APPEND;
}
break;
case ScalarQuantizer::QuantizerType::QT_8bit_uniform:
{
Codec<ScalarQuantizer::QuantizerType::QT_8bit_uniform, 1>
codec(scalarQ->code_size, scalarQ->trained[0], scalarQ->trained[1]);
RUN_APPEND;
}
break;
case ScalarQuantizer::QuantizerType::QT_fp16:
{
Codec<ScalarQuantizer::QuantizerType::QT_fp16, 1>
codec(scalarQ->code_size);
RUN_APPEND;
}
break;
case ScalarQuantizer::QuantizerType::QT_8bit_direct:
@ -373,4 +575,5 @@ runIVFFlatInvertedListAppend(Tensor<int, 1, true>& listIds,
#undef RUN_APPEND
}
} } // namespace


@ -25,9 +25,7 @@ void runUpdateListPointers(Tensor<int, 1, true>& listIds,
thrust::device_vector<void*>& listIndices,
cudaStream_t stream);
/// Actually append the new codes / vector indices to the individual lists
/// IVFPQ
/// Append PQ codes to IVF lists
void runIVFPQInvertedListAppend(Tensor<int, 1, true>& listIds,
Tensor<int, 1, true>& listOffset,
Tensor<int, 2, true>& encodings,
@ -38,17 +36,30 @@ void runIVFPQInvertedListAppend(Tensor<int, 1, true>& listIds,
IndicesOptions indicesOptions,
cudaStream_t stream);
/// IVF flat storage
void runIVFFlatInvertedListAppend(Tensor<int, 1, true>& listIds,
Tensor<int, 1, true>& listOffset,
Tensor<float, 2, true>& vecs,
Tensor<Index::idx_t, 1, true>& indices,
bool useResidual,
Tensor<float, 2, true>& residuals,
GpuScalarQuantizer* scalarQ,
thrust::device_vector<void*>& listData,
thrust::device_vector<void*>& listIndices,
IndicesOptions indicesOptions,
cudaStream_t stream);
/// Append SQ codes to IVF lists (non-interleaved, old format)
void runIVFFlatAppend(Tensor<int, 1, true>& listIds,
Tensor<int, 1, true>& listOffset,
Tensor<float, 2, true>& vecs,
Tensor<Index::idx_t, 1, true>& indices,
GpuScalarQuantizer* scalarQ,
thrust::device_vector<void*>& listData,
thrust::device_vector<void*>& listIndices,
IndicesOptions indicesOptions,
cudaStream_t stream);
/// Append SQ codes to IVF lists (interleaved)
void runIVFFlatInterleavedAppend(Tensor<int, 1, true>& listIds,
Tensor<int, 1, true>& listOffset,
Tensor<int, 1, true>& uniqueLists,
Tensor<int, 1, true>& vectorsByUniqueList,
Tensor<int, 1, true>& uniqueListVectorStart,
Tensor<int, 1, true>& uniqueListStartOffset,
Tensor<float, 2, true>& vecs,
Tensor<Index::idx_t, 1, true>& indices,
GpuScalarQuantizer* scalarQ,
thrust::device_vector<void*>& listData,
thrust::device_vector<void*>& listIndices,
IndicesOptions indicesOptions,
cudaStream_t stream);
} } // namespace


@ -12,6 +12,7 @@
#include <faiss/gpu/impl/FlatIndex.cuh>
#include <faiss/gpu/impl/IVFAppend.cuh>
#include <faiss/gpu/impl/RemapIndices.h>
#include <faiss/gpu/utils/CopyUtils.cuh>
#include <faiss/gpu/utils/DeviceDefs.cuh>
#include <faiss/gpu/utils/DeviceUtils.h>
#include <faiss/gpu/utils/HostTensor.cuh>
@ -30,6 +31,7 @@ IVFBase::IVFBase(GpuResources* resources,
faiss::MetricType metric,
float metricArg,
FlatIndex* quantizer,
bool interleavedLayout,
IndicesOptions indicesOptions,
MemorySpace space) :
resources_(resources),
@ -38,6 +40,7 @@ IVFBase::IVFBase(GpuResources* resources,
quantizer_(quantizer),
dim_(quantizer->getDim()),
numLists_(quantizer->getSize()),
interleavedLayout_(interleavedLayout),
indicesOptions_(indicesOptions),
space_(space),
maxListLength_(0) {
@ -267,7 +270,7 @@ IVFBase::getListIndices(int listId) const {
}
std::vector<uint8_t>
IVFBase::getListVectorData(int listId) const {
IVFBase::getListVectorData(int listId, bool gpuFormat) const {
FAISS_THROW_IF_NOT_FMT(listId < numLists_,
"IVF list %d is out of bounds (%d lists total)",
listId, numLists_);
@ -279,9 +282,13 @@ IVFBase::getListVectorData(int listId) const {
auto& list = deviceListData_[listId];
auto gpuCodes = list->data.copyToHost<uint8_t>(stream);
// The GPU layout may be different than the CPU layout (e.g., vectors rather
// than dimensions interleaved), translate back if necessary
return translateCodesFromGpu_(std::move(gpuCodes), list->numVecs);
if (gpuFormat) {
return gpuCodes;
} else {
// The GPU layout may be different than the CPU layout (e.g., vectors rather
// than dimensions interleaved), translate back if necessary
return translateCodesFromGpu_(std::move(gpuCodes), list->numVecs);
}
}
void
@ -306,7 +313,7 @@ void
IVFBase::copyInvertedListsTo(InvertedLists* ivf) {
for (int i = 0; i < numLists_; ++i) {
auto listIndices = getListIndices(i);
auto listData = getListVectorData(i);
auto listData = getListVectorData(i, false);
ivf->add_entries(i, listIndices.size(), listIndices.data(), listData.data());
}
@ -413,7 +420,6 @@ IVFBase::addIndicesFromCpu_(int listId,
deviceListIndexPointers_[listId] = listIndices->data.data();
}
int
IVFBase::addVectors(Tensor<float, 2, true>& vecs,
Tensor<Index::idx_t, 1, true>& indices) {
@ -422,9 +428,6 @@ IVFBase::addVectors(Tensor<float, 2, true>& vecs,
auto stream = resources_->getDefaultStreamCurrentDevice();
// Number of valid vectors that we actually add; we return this
int numAdded = 0;
// Determine which IVF lists we need to append to
// We don't actually need this
@ -433,7 +436,6 @@ IVFBase::addVectors(Tensor<float, 2, true>& vecs,
// We use this
DeviceTensor<int, 2, true> listIds2d(
resources_, makeTempAlloc(AllocType::Other, stream), {vecs.getSize(0), 1});
auto listIds = listIds2d.view<1>({vecs.getSize(0)});
quantizer_->query(vecs, 1, metric_, metricArg_,
listDistance, listIds2d, false);
@ -442,39 +444,47 @@ IVFBase::addVectors(Tensor<float, 2, true>& vecs,
// FIXME: really this can be into pinned memory and a true async
// copy on a different stream; we can start the copy early, but it's
// tiny
HostTensor<int, 1, true> listIdsHost(listIds, stream);
auto listIdsHost = listIds2d.copyToVector(stream);
// Now we add the encoded vectors to the individual lists
// First, make sure that there is space available for adding the new
// encoded vectors and indices
// list id -> # being added
std::unordered_map<int, int> assignCounts;
// list id -> vectors being added
std::unordered_map<int, std::vector<int>> listToVectorIds;
// vector id -> which list it is being appended to
std::vector<int> vectorIdToList(vecs.getSize(0));
// vector id -> offset in list
// (we already have vector id -> list id in listIds)
HostTensor<int, 1, true> listOffsetHost({listIdsHost.getSize(0)});
std::vector<int> listOffsetHost(listIdsHost.size());
for (int i = 0; i < listIdsHost.getSize(0); ++i) {
// Number of valid vectors that we actually add; we return this
int numAdded = 0;
for (int i = 0; i < listIdsHost.size(); ++i) {
int listId = listIdsHost[i];
// Add vector could be invalid (contains NaNs etc)
if (listId < 0) {
listOffsetHost[i] = -1;
vectorIdToList[i] = -1;
continue;
}
FAISS_ASSERT(listId < numLists_);
++numAdded;
vectorIdToList[i] = listId;
int offset = deviceListData_[listId]->numVecs;
auto it = assignCounts.find(listId);
if (it != assignCounts.end()) {
offset += it->second;
it->second++;
auto it = listToVectorIds.find(listId);
if (it != listToVectorIds.end()) {
offset += it->second.size();
it->second.push_back(i);
} else {
assignCounts[listId] = 1;
listToVectorIds[listId] = std::vector<int>{i};
}
listOffsetHost[i] = offset;
@ -486,35 +496,84 @@ IVFBase::addVectors(Tensor<float, 2, true>& vecs,
return 0;
}
// unique lists being added to
std::vector<int> uniqueLists;
for (auto& vecs : listToVectorIds) {
uniqueLists.push_back(vecs.first);
}
std::sort(uniqueLists.begin(), uniqueLists.end());
// In the same order as uniqueLists, list the vectors being added to that list
// contiguously
// (unique list 0 vectors ...)(unique list 1 vectors ...) ...
std::vector<int> vectorsByUniqueList;
// For each of the unique lists, the start offset in vectorsByUniqueList
std::vector<int> uniqueListVectorStart;
// For each of the unique lists, where we start appending in that list by the
// vector offset
std::vector<int> uniqueListStartOffset;
// For each of the unique lists, find the vectors which should be appended to
// that list
for (auto ul : uniqueLists) {
uniqueListVectorStart.push_back(vectorsByUniqueList.size());
FAISS_ASSERT(listToVectorIds.count(ul) != 0);
// The vectors we are adding to this list
auto& vecs = listToVectorIds[ul];
vectorsByUniqueList.insert(
vectorsByUniqueList.end(), vecs.begin(), vecs.end());
// How many vectors we previously had (which is where we start appending on
// the device)
uniqueListStartOffset.push_back(deviceListData_[ul]->numVecs);
}
// We terminate uniqueListVectorStart with the overall number of vectors being
// added, which could be different than vecs.getSize(0) as some vectors could
// be invalid
uniqueListVectorStart.push_back(vectorsByUniqueList.size());
// We need to resize the data structures for the inverted lists on
// the GPUs, which means that they might need reallocation, which
// means that their base address may change. Figure out the new base
// addresses, and update those in a batch on the device
{
// Resize all of the lists that we are appending to
for (auto& counts : assignCounts) {
auto& codes = deviceListData_[counts.first];
codes->data.resize(
getGpuVectorsEncodingSize_(codes->numVecs + counts.second), stream);
for (auto& counts : listToVectorIds) {
auto listId = counts.first;
int numVecsToAdd = counts.second.size();
int newNumVecs = codes->numVecs + counts.second;
auto& codes = deviceListData_[listId];
int oldNumVecs = codes->numVecs;
int newNumVecs = codes->numVecs + numVecsToAdd;
auto newSizeBytes = getGpuVectorsEncodingSize_(newNumVecs);
codes->data.resize(newSizeBytes, stream);
codes->numVecs = newNumVecs;
auto& indices = deviceListIndices_[counts.first];
auto& indices = deviceListIndices_[listId];
if ((indicesOptions_ == INDICES_32_BIT) ||
(indicesOptions_ == INDICES_64_BIT)) {
size_t indexSize =
(indicesOptions_ == INDICES_32_BIT) ? sizeof(int) : sizeof(Index::idx_t);
(indicesOptions_ == INDICES_32_BIT) ?
sizeof(int) : sizeof(Index::idx_t);
indices->data.resize(
indices->data.size() + counts.second * indexSize, stream);
indices->data.size() + numVecsToAdd * indexSize, stream);
FAISS_ASSERT(indices->numVecs == oldNumVecs);
indices->numVecs = newNumVecs;
} else if (indicesOptions_ == INDICES_CPU) {
// indices are stored on the CPU side
FAISS_ASSERT(counts.first < listOffsetToUserIndex_.size());
FAISS_ASSERT(listId < listOffsetToUserIndex_.size());
auto& userIndices = listOffsetToUserIndex_[counts.first];
auto& userIndices = listOffsetToUserIndex_[listId];
userIndices.resize(newNumVecs);
} else {
// indices are not stored on the GPU or CPU side
@ -528,15 +587,7 @@ IVFBase::addVectors(Tensor<float, 2, true>& vecs,
// Update all pointers and sizes on the device for lists that we
// appended to
{
std::vector<int> listIdsV(assignCounts.size());
int i = 0;
for (auto& counts : assignCounts) {
listIdsV[i++] = counts.first;
}
updateDeviceListInfo_(listIdsV, stream);
}
updateDeviceListInfo_(uniqueLists, stream);
}
// If we're maintaining the indices on the CPU side, update our
@ -554,6 +605,7 @@ IVFBase::addVectors(Tensor<float, 2, true>& vecs,
}
int offset = listOffsetHost[i];
FAISS_ASSERT(offset >= 0);
FAISS_ASSERT(listId < listOffsetToUserIndex_.size());
auto& userIndices = listOffsetToUserIndex_[listId];
@ -564,15 +616,31 @@ IVFBase::addVectors(Tensor<float, 2, true>& vecs,
}
// Copy the offsets to the GPU
DeviceTensor<int, 1, true> listOffset(
resources_, makeTempAlloc(AllocType::Other, stream), listOffsetHost);
auto listIdsDevice = listIds2d.downcastOuter<1>();
auto listOffsetDevice =
toDeviceTemporary(resources_, listOffsetHost, stream);
auto uniqueListsDevice =
toDeviceTemporary(resources_, uniqueLists, stream);
auto vectorsByUniqueListDevice =
toDeviceTemporary(resources_, vectorsByUniqueList, stream);
auto uniqueListVectorStartDevice =
toDeviceTemporary(resources_, uniqueListVectorStart, stream);
auto uniqueListStartOffsetDevice =
toDeviceTemporary(resources_, uniqueListStartOffset, stream);
// Actually encode and append the vectors
appendVectors_(vecs, indices, listIds, listOffset, stream);
appendVectors_(vecs,
indices,
uniqueListsDevice,
vectorsByUniqueListDevice,
uniqueListVectorStartDevice,
uniqueListStartOffsetDevice,
listIdsDevice,
listOffsetDevice,
stream);
// We added this number
return numAdded;
}
} } // namespace


@ -32,6 +32,7 @@ class IVFBase {
float metricArg,
/// We do not own this reference
FlatIndex* quantizer,
bool interleavedLayout,
IndicesOptions indicesOptions,
MemorySpace space);
@ -62,7 +63,7 @@ class IVFBase {
std::vector<Index::idx_t> getListIndices(int listId) const;
/// Return the encoded vectors of a particular list back to the CPU
std::vector<uint8_t> getListVectorData(int listId) const;
std::vector<uint8_t> getListVectorData(int listId, bool gpuFormat) const;
/// Copy all inverted lists from a CPU representation to ourselves
void copyInvertedListsFrom(const InvertedLists* ivf);
@ -105,6 +106,10 @@ class IVFBase {
/// Append vectors to our on-device lists
virtual void appendVectors_(Tensor<float, 2, true>& vecs,
Tensor<Index::idx_t, 1, true>& indices,
Tensor<int, 1, true>& uniqueLists,
Tensor<int, 1, true>& vectorsByUniqueList,
Tensor<int, 1, true>& uniqueListVectorStart,
Tensor<int, 1, true>& uniqueListStartOffset,
Tensor<int, 1, true>& listIds,
Tensor<int, 1, true>& listOffset,
cudaStream_t stream) = 0;
@ -145,6 +150,17 @@ class IVFBase {
/// Number of inverted lists we maintain
const int numLists_;
/// Whether or not our index uses an interleaved by 32 layout:
/// The default memory layout is [vector][PQ/SQ component]:
/// (v0 d0) (v0 d1) ... (v0 dD-1) (v1 d0) (v1 d1) ...
///
/// The interleaved by 32 memory layout is:
/// [vector / 32][PQ/SQ component][vector % 32] with padding:
/// (v0 d0) (v1 d0) ... (v31 d0) (v0 d1) (v1 d1) ... (v31 dD-1) (v32 d0) (v33
/// d0) ...
/// so the list length is always a multiple of num quantizers * 32
bool interleavedLayout_;
/// How are user indices stored on the GPU?
const IndicesOptions indicesOptions_;
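As a quick check on the layout comment above, here is a minimal sketch (illustrative helper, not part of the diff) of the byte offset of a single code under the interleaved-by-32 layout for whole-byte codes; SQ6/SQ4 keep the same block structure but bit-pack 32 codes of one dimension into 24 or 16 bytes.

```
#include <cassert>
#include <cstddef>

// Byte offset of (vector v, component d) in the interleaved-by-32 layout,
// for codes that occupy whole bytes (SQ8 = 1, fp16 = 2, float32 = 4).
size_t interleavedOffset(size_t v, size_t d, size_t dims, size_t bytesPerCode) {
  size_t block = v / 32;                       // group of 32 vectors
  size_t lane = v % 32;                        // position within the group
  size_t bytesPerDimBlock = 32 * bytesPerCode; // one component of 32 vectors
  size_t bytesPerBlock = bytesPerDimBlock * dims;
  return block * bytesPerBlock + d * bytesPerDimBlock + lane * bytesPerCode;
}

int main() {
  // SQ8, dims = 4: vector 33 is lane 1 of block 1, so component 2 lives at
  // 1 * (32 * 4) + 2 * 32 + 1 = 193
  assert(interleavedOffset(33, 2, 4, 1) == 193);
  return 0;
}
```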


@ -9,8 +9,10 @@
#include <faiss/gpu/impl/IVFFlat.cuh>
#include <faiss/gpu/GpuResources.h>
#include <faiss/gpu/impl/FlatIndex.cuh>
#include <faiss/gpu/impl/InterleavedCodes.h>
#include <faiss/gpu/impl/IVFAppend.cuh>
#include <faiss/gpu/impl/IVFFlatScan.cuh>
#include <faiss/gpu/impl/IVFInterleaved.cuh>
#include <faiss/gpu/impl/RemapIndices.h>
#include <faiss/gpu/utils/ConversionOperators.cuh>
#include <faiss/gpu/utils/CopyUtils.cuh>
@ -31,12 +33,14 @@ IVFFlat::IVFFlat(GpuResources* res,
float metricArg,
bool useResidual,
faiss::ScalarQuantizer* scalarQ,
bool interleavedLayout,
IndicesOptions indicesOptions,
MemorySpace space) :
IVFBase(res,
metric,
metricArg,
quantizer,
interleavedLayout,
indicesOptions,
space),
useResidual_(useResidual),
@ -48,35 +52,72 @@ IVFFlat::~IVFFlat() {
size_t
IVFFlat::getGpuVectorsEncodingSize_(int numVecs) const {
return (size_t) numVecs *
// size per encoded vector
(scalarQ_ ? scalarQ_->code_size : sizeof(float) * getDim());
if (interleavedLayout_) {
// bits per scalar code
int bits = scalarQ_ ? scalarQ_->bits : 32 /* float */;
// bytes to encode a block of 32 vectors (single dimension)
int bytesPerDimBlock = bits * 32 / 8;
// bytes to fully encode 32 vectors
int bytesPerBlock = bytesPerDimBlock * dim_;
// number of blocks of 32 vectors we have
int numBlocks = utils::divUp(numVecs, 32);
// total size to encode numVecs
return bytesPerBlock * numBlocks;
} else {
size_t sizePerVector =
(scalarQ_ ? scalarQ_->code_size : sizeof(float) * dim_);
return (size_t) numVecs * sizePerVector;
}
}
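A worked instance of the size computation above, using the SQ6 parameters from the summary (the helper is a standalone sketch mirroring getGpuVectorsEncodingSize_, not a faiss API):

```
#include <cstddef>
#include <cstdio>

// Interleaved-by-32 storage size, mirroring getGpuVectorsEncodingSize_.
size_t interleavedEncodingSize(int numVecs, int dim, int bits) {
  int bytesPerDimBlock = bits * 32 / 8;       // SQ6: 24 bytes per dim per block
  int bytesPerBlock = bytesPerDimBlock * dim; // full block of 32 vectors
  int numBlocks = (numVecs + 31) / 32;        // utils::divUp(numVecs, 32)
  return (size_t) bytesPerBlock * numBlocks;
}

int main() {
  // SQ6, dim = 128, 100 vectors -> 4 blocks * 128 dims * 24 bytes = 12288 bytes
  std::printf("%zu\n", interleavedEncodingSize(100, 128, 6));
  return 0;
}
```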
size_t
IVFFlat::getCpuVectorsEncodingSize_(int numVecs) const {
return (size_t) numVecs *
// size per encoded vector
(scalarQ_ ? scalarQ_->code_size : sizeof(float) * getDim());
size_t sizePerVector =
(scalarQ_ ? scalarQ_->code_size : sizeof(float) * dim_);
return (size_t) numVecs * sizePerVector;
}
std::vector<uint8_t>
IVFFlat::translateCodesToGpu_(std::vector<uint8_t> codes,
size_t numVecs) const {
// nothing to do
return codes;
if (!interleavedLayout_) {
// same format
return codes;
}
int bitsPerCode = scalarQ_ ? scalarQ_->bits : 32;
auto up = unpackNonInterleaved(std::move(codes), numVecs, dim_, bitsPerCode);
return packInterleaved(std::move(up), numVecs, dim_, bitsPerCode);
}
std::vector<uint8_t>
IVFFlat::translateCodesFromGpu_(std::vector<uint8_t> codes,
size_t numVecs) const {
// nothing to do
return codes;
if (!interleavedLayout_) {
// same format
return codes;
}
int bitsPerCode = scalarQ_ ? scalarQ_->bits : 32;
auto up = unpackInterleaved(std::move(codes), numVecs, dim_, bitsPerCode);
return packNonInterleaved(std::move(up), numVecs, dim_, bitsPerCode);
}
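In other words, the CPU/GPU translation is a pure byte-level re-pack through the InterleavedCodes helpers added later in this diff. A hedged usage sketch for SQ6 (the helper below is hypothetical; it only checks that the conversion round-trips):

```
#include <faiss/gpu/impl/InterleavedCodes.h>

#include <cstdint>
#include <vector>

// Sketch: convert nv SQ6 codes from the CPU bit-packed layout to the GPU
// interleaved-by-32 layout and back; the result should equal the input.
bool sq6RoundTripsOk(const std::vector<uint8_t>& cpuCodes, int nv, int dim) {
  // CPU -> GPU: widen each 6-bit code to a byte, then interleave by 32
  auto bytes = faiss::gpu::unpackNonInterleaved(cpuCodes, nv, dim, 6);
  auto gpu = faiss::gpu::packInterleaved(std::move(bytes), nv, dim, 6);

  // GPU -> CPU: de-interleave, then bit-pack per vector again
  auto back = faiss::gpu::unpackInterleaved(std::move(gpu), nv, dim, 6);
  return faiss::gpu::packNonInterleaved(std::move(back), nv, dim, 6) == cpuCodes;
}
```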
void
IVFFlat::appendVectors_(Tensor<float, 2, true>& vecs,
Tensor<Index::idx_t, 1, true>& indices,
Tensor<int, 1, true>& uniqueLists,
Tensor<int, 1, true>& vectorsByUniqueList,
Tensor<int, 1, true>& uniqueListVectorStart,
Tensor<int, 1, true>& uniqueListStartOffset,
Tensor<int, 1, true>& listIds,
Tensor<int, 1, true>& listOffset,
cudaStream_t stream) {
@ -93,17 +134,31 @@ IVFFlat::appendVectors_(Tensor<float, 2, true>& vecs,
}
// Now, for each list to which a vector is being assigned, write it
runIVFFlatInvertedListAppend(listIds,
listOffset,
vecs,
indices,
useResidual_,
residuals,
scalarQ_.get(),
deviceListDataPointers_,
deviceListIndexPointers_,
indicesOptions_,
stream);
if (interleavedLayout_) {
runIVFFlatInterleavedAppend(listIds,
listOffset,
uniqueLists,
vectorsByUniqueList,
uniqueListVectorStart,
uniqueListStartOffset,
useResidual_ ? residuals : vecs,
indices,
scalarQ_.get(),
deviceListDataPointers_,
deviceListIndexPointers_,
indicesOptions_,
stream);
} else {
runIVFFlatAppend(listIds,
listOffset,
useResidual_ ? residuals : vecs,
indices,
scalarQ_.get(),
deviceListDataPointers_,
deviceListIndexPointers_,
indicesOptions_,
stream);
}
}
void
@ -149,21 +204,38 @@ IVFFlat::query(Tensor<float, 2, true>& queries,
quantizer_->reconstruct(coarseIndices, residualBase);
}
runIVFFlatScan(queries,
coarseIndices,
deviceListDataPointers_,
deviceListIndexPointers_,
indicesOptions_,
deviceListLengths_,
maxListLength_,
k,
metric_,
useResidual_,
residualBase,
scalarQ_.get(),
outDistances,
outIndices,
resources_);
if (interleavedLayout_) {
runIVFInterleavedScan(queries,
coarseIndices,
deviceListDataPointers_,
deviceListIndexPointers_,
indicesOptions_,
deviceListLengths_,
k,
metric_,
useResidual_,
residualBase,
scalarQ_.get(),
outDistances,
outIndices,
resources_);
} else {
runIVFFlatScan(queries,
coarseIndices,
deviceListDataPointers_,
deviceListIndexPointers_,
indicesOptions_,
deviceListLengths_,
maxListLength_,
k,
metric_,
useResidual_,
residualBase,
scalarQ_.get(),
outDistances,
outIndices,
resources_);
}
// If the GPU isn't storing indices (they are on the CPU side), we
// need to perform the re-mapping here


@ -24,6 +24,7 @@ class IVFFlat : public IVFBase {
bool useResidual,
/// Optional ScalarQuantizer
faiss::ScalarQuantizer* scalarQ,
bool interleavedLayout,
IndicesOptions indicesOptions,
MemorySpace space);
@ -56,6 +57,10 @@ class IVFFlat : public IVFBase {
/// Encode the vectors that we're adding and append to our IVF lists
void appendVectors_(Tensor<float, 2, true>& vecs,
Tensor<Index::idx_t, 1, true>& indices,
Tensor<int, 1, true>& uniqueLists,
Tensor<int, 1, true>& vectorsByUniqueList,
Tensor<int, 1, true>& uniqueListVectorStart,
Tensor<int, 1, true>& uniqueListStartOffset,
Tensor<int, 1, true>& listIds,
Tensor<int, 1, true>& listOffset,
cudaStream_t stream) override;

View File

@ -10,17 +10,16 @@
#include <faiss/gpu/impl/DistanceUtils.cuh>
#include <faiss/gpu/impl/IVFUtils.cuh>
#include <faiss/gpu/GpuResources.h>
#include <faiss/gpu/utils/Comparators.cuh>
#include <faiss/gpu/utils/ConversionOperators.cuh>
#include <faiss/gpu/utils/DeviceDefs.cuh>
#include <faiss/gpu/utils/DeviceUtils.h>
#include <faiss/gpu/utils/DeviceTensor.cuh>
#include <faiss/gpu/utils/Float16.cuh>
#include <faiss/gpu/utils/MathOperators.cuh>
#include <faiss/gpu/utils/LoadStoreOperators.cuh>
#include <faiss/gpu/utils/PtxUtils.cuh>
#include <faiss/gpu/utils/Reductions.cuh>
#include <faiss/gpu/utils/StaticUtils.h>
#include <thrust/host_vector.h>
namespace faiss { namespace gpu {
@ -162,7 +161,8 @@ ivfFlatScan(Tensor<float, 2, true> queries,
auto residualBaseSlice = residualBase[queryId][probeId].data();
codec.setSmem(smem, dim);
codec.initKernel(smem, dim);
__syncthreads();
IVFFlatScan<Codec, Metric>::scan(query,
useResidual,
@ -253,51 +253,25 @@ runIVFFlatScanTile(GpuResources* res,
switch (scalarQ->qtype) {
case ScalarQuantizer::QuantizerType::QT_8bit:
{
// FIXME: investigate 32 bit load perf issues
// if (dim % 4 == 0) {
if (false) {
Codec<ScalarQuantizer::QuantizerType::QT_8bit, 4>
codec(scalarQ->code_size,
scalarQ->gpuTrained.data(),
scalarQ->gpuTrained.data() + dim);
HANDLE_METRICS;
} else {
Codec<ScalarQuantizer::QuantizerType::QT_8bit, 1>
codec(scalarQ->code_size,
scalarQ->gpuTrained.data(),
scalarQ->gpuTrained.data() + dim);
HANDLE_METRICS;
}
Codec<ScalarQuantizer::QuantizerType::QT_8bit, 1>
codec(scalarQ->code_size,
scalarQ->gpuTrained.data(),
scalarQ->gpuTrained.data() + dim);
HANDLE_METRICS;
}
break;
case ScalarQuantizer::QuantizerType::QT_8bit_uniform:
{
// FIXME: investigate 32 bit load perf issues
if (false) {
// if (dim % 4 == 0) {
Codec<ScalarQuantizer::QuantizerType::QT_8bit_uniform, 4>
codec(scalarQ->code_size, scalarQ->trained[0], scalarQ->trained[1]);
HANDLE_METRICS;
} else {
Codec<ScalarQuantizer::QuantizerType::QT_8bit_uniform, 1>
codec(scalarQ->code_size, scalarQ->trained[0], scalarQ->trained[1]);
HANDLE_METRICS;
}
Codec<ScalarQuantizer::QuantizerType::QT_8bit_uniform, 1>
codec(scalarQ->code_size, scalarQ->trained[0], scalarQ->trained[1]);
HANDLE_METRICS;
}
break;
case ScalarQuantizer::QuantizerType::QT_fp16:
{
if (false) {
// FIXME: investigate 32 bit load perf issues
// if (dim % 2 == 0) {
Codec<ScalarQuantizer::QuantizerType::QT_fp16, 2>
codec(scalarQ->code_size);
HANDLE_METRICS;
} else {
Codec<ScalarQuantizer::QuantizerType::QT_fp16, 1>
codec(scalarQ->code_size);
HANDLE_METRICS;
}
Codec<ScalarQuantizer::QuantizerType::QT_fp16, 1>
codec(scalarQ->code_size);
HANDLE_METRICS;
}
break;
case ScalarQuantizer::QuantizerType::QT_8bit_direct:
@ -384,7 +358,6 @@ runIVFFlatScan(Tensor<float, 2, true>& queries,
constexpr int kThrustMemSize = 16384;
int nprobe = listIds.getSize(1);
auto stream = res->getDefaultStreamCurrentDevice();
// Make a reservation for Thrust to do its dirty work (global memory


@ -0,0 +1,205 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <faiss/gpu/impl/IVFInterleaved.cuh>
#include <faiss/gpu/impl/scan/IVFInterleavedImpl.cuh>
namespace faiss { namespace gpu {
constexpr uint32_t kMaxUInt32 = std::numeric_limits<uint32_t>::max();
// Second-pass kernel to further k-select the results from the first pass across
// IVF lists and produce the final results
template <int ThreadsPerBlock, int NumWarpQ, int NumThreadQ>
__global__ void
ivfInterleavedScan2(Tensor<float, 3, true> distanceIn,
Tensor<int, 3, true> indicesIn,
Tensor<int, 2, true> listIds,
int k,
void** listIndices,
IndicesOptions opt,
bool dir,
Tensor<float, 2, true> distanceOut,
Tensor<Index::idx_t, 2, true> indicesOut) {
int queryId = blockIdx.x;
constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
__shared__ float smemK[kNumWarps * NumWarpQ];
__shared__ uint32_t smemV[kNumWarps * NumWarpQ];
// To avoid creating excessive specializations, we combine direction kernels,
// selecting for the smallest element. If `dir` is true, we negate all values
// being selected (so that we are selecting the largest element).
BlockSelect<float, uint32_t, false, Comparator<float>,
NumWarpQ, NumThreadQ, ThreadsPerBlock>
heap(kFloatMax, kMaxUInt32, smemK, smemV, k);
// nprobe x k
int num = distanceIn.getSize(1) * distanceIn.getSize(2);
auto distanceBase = distanceIn[queryId].data();
int limit = utils::roundDown(num, kWarpSize);
// This will keep our negation factor
float adj = dir ? -1 : 1;
int i = threadIdx.x;
for (; i < limit; i += blockDim.x) {
// We represent the index as (probe id)(k)
// Right now, both are limited to a maximum of 2048; we pack them into the
// high and low 16 bits of a uint32_t
static_assert(GPU_MAX_SELECTION_K <= 65536, "");
uint32_t curProbe = i / k;
uint32_t curK = i % k;
uint32_t index = (curProbe << 16) | (curK & (uint32_t) 0xffff);
int listId = listIds[queryId][curProbe];
if (listId != -1) {
// Adjust the value we are selecting based on the sorting order
heap.addThreadQ(distanceBase[i] * adj, index);
}
heap.checkThreadQ();
}
// Handle warp divergence separately
if (i < num) {
uint32_t curProbe = i / k;
uint32_t curK = i % k;
uint32_t index = (curProbe << 16) | (curK & (uint32_t) 0xffff);
int listId = listIds[queryId][curProbe];
if (listId != -1) {
heap.addThreadQ(distanceBase[i] * adj, index);
}
}
// Merge all final results
heap.reduce();
for (int i = threadIdx.x; i < k; i += blockDim.x) {
// Re-adjust the value we are selecting based on the sorting order
distanceOut[queryId][i] = smemK[i] * adj;
auto packedIndex = smemV[i];
// We need to remap to the user-provided indices
Index::idx_t index = -1;
// We may not have at least k values to return; in this function, max uint32
// is our sentinel value
if (packedIndex != kMaxUInt32) {
uint32_t curProbe = packedIndex >> 16;
uint32_t curK = packedIndex & 0xffff;
int listId = listIds[queryId][curProbe];
int listOffset = indicesIn[queryId][curProbe][curK];
if (opt == INDICES_32_BIT) {
index = (Index::idx_t) ((int*) listIndices[listId])[listOffset];
} else if (opt == INDICES_64_BIT) {
index = ((Index::idx_t*) listIndices[listId])[listOffset];
} else {
index = ((Index::idx_t) listId << 32 | (Index::idx_t) listOffset);
}
}
indicesOut[queryId][i] = index;
}
}
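The only non-obvious bookkeeping in this second pass is the packed (probe, rank) id; below is a tiny host-side mirror of that encoding (illustrative helpers matching the shifts used in the kernel):

```
#include <cassert>
#include <cstdint>

// Probe id in the high 16 bits, rank within that probe's k results in the low
// 16 bits; both are at most 2048 today, well within 16 bits.
uint32_t packProbeRank(uint32_t probe, uint32_t rank) {
  return (probe << 16) | (rank & 0xffffu);
}

void unpackProbeRank(uint32_t packed, uint32_t& probe, uint32_t& rank) {
  probe = packed >> 16;
  rank = packed & 0xffffu;
}

int main() {
  uint32_t probe = 0, rank = 0;
  unpackProbeRank(packProbeRank(12, 345), probe, rank);
  assert(probe == 12 && rank == 345);
  return 0;
}
```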
void
runIVFInterleavedScan2(Tensor<float, 3, true>& distanceIn,
Tensor<int, 3, true>& indicesIn,
Tensor<int, 2, true>& listIds,
int k,
thrust::device_vector<void*>& listIndices,
IndicesOptions indicesOptions,
bool dir,
Tensor<float, 2, true>& distanceOut,
Tensor<Index::idx_t, 2, true>& indicesOut,
cudaStream_t stream) {
#define IVF_SCAN_2(THREADS, NUM_WARP_Q, NUM_THREAD_Q) \
ivfInterleavedScan2<THREADS, NUM_WARP_Q, NUM_THREAD_Q> \
<<<distanceIn.getSize(0), THREADS, 0, stream>>>( \
distanceIn, \
indicesIn, \
listIds, \
k, \
listIndices.data().get(), \
indicesOptions, \
dir, \
distanceOut, \
indicesOut)
if (k == 1) {
IVF_SCAN_2(128, 1, 1);
} else if (k <= 32) {
IVF_SCAN_2(128, 32, 2);
} else if (k <= 64) {
IVF_SCAN_2(128, 64, 3);
} else if (k <= 128) {
IVF_SCAN_2(128, 128, 3);
} else if (k <= 256) {
IVF_SCAN_2(128, 256, 4);
} else if (k <= 512) {
IVF_SCAN_2(128, 512, 8);
} else if (k <= 1024) {
IVF_SCAN_2(128, 1024, 8);
}
#if GPU_MAX_SELECTION_K >= 2048
else if (k <= 2048) {
IVF_SCAN_2(64, 2048, 8);
}
#endif
}
void
runIVFInterleavedScan(Tensor<float, 2, true>& queries,
Tensor<int, 2, true>& listIds,
thrust::device_vector<void*>& listData,
thrust::device_vector<void*>& listIndices,
IndicesOptions indicesOptions,
thrust::device_vector<int>& listLengths,
int k,
faiss::MetricType metric,
bool useResidual,
Tensor<float, 3, true>& residualBase,
GpuScalarQuantizer* scalarQ,
// output
Tensor<float, 2, true>& outDistances,
// output
Tensor<Index::idx_t, 2, true>& outIndices,
GpuResources* res) {
// invalid k is caught with an exception at a higher level
FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);
if (k == 1) {
IVF_INTERLEAVED_CALL(1);
} else if (k <= 32) {
IVF_INTERLEAVED_CALL(32);
} else if (k <= 64) {
IVF_INTERLEAVED_CALL(64);
} else if (k <= 128) {
IVF_INTERLEAVED_CALL(128);
} else if (k <= 256) {
IVF_INTERLEAVED_CALL(256);
} else if (k <= 512) {
IVF_INTERLEAVED_CALL(512);
} else if (k <= 1024) {
IVF_INTERLEAVED_CALL(1024);
}
#if GPU_MAX_SELECTION_K >= 2048
else if (k <= 2048) {
IVF_INTERLEAVED_CALL(2048);
}
#endif
}
} } // namespace


@ -0,0 +1,383 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <faiss/MetricType.h>
#include <faiss/gpu/GpuIndicesOptions.h>
#include <faiss/gpu/GpuResources.h>
#include <faiss/gpu/impl/DistanceUtils.cuh>
#include <faiss/gpu/impl/GpuScalarQuantizer.cuh>
#include <faiss/gpu/utils/Comparators.cuh>
#include <faiss/gpu/utils/ConversionOperators.cuh>
#include <faiss/gpu/utils/DeviceDefs.cuh>
#include <faiss/gpu/utils/DeviceTensor.cuh>
#include <faiss/gpu/utils/DeviceUtils.h>
#include <faiss/gpu/utils/Float16.cuh>
#include <faiss/gpu/utils/MathOperators.cuh>
#include <faiss/gpu/utils/PtxUtils.cuh>
#include <faiss/gpu/utils/Select.cuh>
#include <faiss/gpu/utils/StaticUtils.h>
#include <faiss/gpu/utils/WarpPackedBits.cuh>
#include <thrust/device_vector.h>
namespace faiss { namespace gpu {
/// First pass kernel to perform scanning of IVF lists to produce top-k
/// candidates
template <typename Codec,
typename Metric,
int ThreadsPerBlock,
int NumWarpQ,
int NumThreadQ,
bool Residual>
__global__ void
ivfInterleavedScan(Tensor<float, 2, true> queries,
Tensor<float, 3, true> residualBase,
Tensor<int, 2, true> listIds,
void** allListData,
int* listLengths,
Codec codec,
Metric metric,
int k,
// [query][probe][k]
Tensor<float, 3, true> distanceOut,
Tensor<int, 3, true> indicesOut) {
extern __shared__ float smem[];
constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;
int queryId = blockIdx.y;
int probeId = blockIdx.x;
int listId = listIds[queryId][probeId];
// Safety guard in case NaNs in input cause no list ID to be generated, or we
// have more nprobe than nlist
if (listId == -1) {
return;
}
int dim = queries.getSize(1);
// FIXME: some issue with getLaneId() and CUDA 10.1 and P4 GPUs?
int laneId = threadIdx.x % kWarpSize;
int warpId = threadIdx.x / kWarpSize;
using EncodeT = typename Codec::EncodeT;
auto query = queries[queryId].data();
auto vecsBase = (EncodeT*) allListData[listId];
int numVecs = listLengths[listId];
auto residualBaseSlice = residualBase[queryId][probeId].data();
constexpr auto kInit = Metric::kDirection ? kFloatMin : kFloatMax;
__shared__ float smemK[kNumWarps * NumWarpQ];
__shared__ int smemV[kNumWarps * NumWarpQ];
BlockSelect<float, int, Metric::kDirection, Comparator<float>,
NumWarpQ, NumThreadQ, ThreadsPerBlock>
heap(kInit, -1, smemK, smemV, k);
// The codec might be dependent upon data that we need to reference or store
// in shared memory
codec.initKernel(smem, dim);
__syncthreads();
// How many vector blocks of 32 are in this list?
int numBlocks = utils::divUp(numVecs, 32);
// Bytes and EncodeT words that one dimension of a 32-vector block occupies
constexpr int bytesPerVectorBlockDim =
Codec::kEncodeBits * 32 / 8;
constexpr int wordsPerVectorBlockDim =
bytesPerVectorBlockDim / sizeof(EncodeT);
int wordsPerVectorBlock = wordsPerVectorBlockDim * dim;
int dimBlocks = utils::roundDown(dim, kWarpSize);
for (int block = warpId; block < numBlocks; block += kNumWarps) {
// We're handling a new vector
Metric dist = metric.zero();
// This is the vector a given lane/thread handles
int vec = block * kWarpSize + laneId;
bool valid = vec < numVecs;
// This is where this warp begins reading data
EncodeT* data = vecsBase + block * wordsPerVectorBlock;
// whole blocks
for (int dBase = 0; dBase < dimBlocks; dBase += kWarpSize) {
int loadDim = dBase + laneId;
float queryReg = query[loadDim];
float residualReg = Residual ? residualBaseSlice[loadDim] : 0;
constexpr int kUnroll = 4;
#pragma unroll
for (int i = 0; i < kWarpSize / kUnroll;
++i, data += kUnroll * wordsPerVectorBlockDim) {
EncodeT encV[kUnroll];
#pragma unroll
for (int j = 0; j < kUnroll; ++j) {
encV[j] = WarpPackedBits<EncodeT, Codec::kEncodeBits>::
read(laneId, data + j * wordsPerVectorBlockDim);
}
#pragma unroll
for (int j = 0; j < kUnroll; ++j) {
encV[j] = WarpPackedBits<EncodeT, Codec::kEncodeBits>::
postRead(laneId, encV[j]);
}
float decV[kUnroll];
#pragma unroll
for (int j = 0; j < kUnroll; ++j) {
int d = i * kUnroll + j;
decV[j] = codec.decodeNew(dBase + d, encV[j]);
}
if (Residual) {
#pragma unroll
for (int j = 0; j < kUnroll; ++j) {
int d = i * kUnroll + j;
decV[j] += SHFL_SYNC(residualReg, d, kWarpSize);
}
}
#pragma unroll
for (int j = 0; j < kUnroll; ++j) {
int d = i * kUnroll + j;
float q = SHFL_SYNC(queryReg, d, kWarpSize);
dist.handle(q, decV[j]);
}
}
}
// remainder
int loadDim = dimBlocks + laneId;
bool loadDimInBounds = loadDim < dim;
float queryReg = loadDimInBounds ? query[loadDim] : 0;
float residualReg = Residual && loadDimInBounds ? residualBaseSlice[loadDim] : 0;
for (int d = 0; d < dim - dimBlocks; ++d, data += wordsPerVectorBlockDim) {
float q = SHFL_SYNC(queryReg, d, kWarpSize);
EncodeT enc =
WarpPackedBits<EncodeT, Codec::kEncodeBits>::read(laneId, data);
enc =
WarpPackedBits<EncodeT, Codec::kEncodeBits>::postRead(laneId, enc);
float dec = codec.decodeNew(dimBlocks + d, enc);
if (Residual) {
dec += SHFL_SYNC(residualReg, d, kWarpSize);
}
dist.handle(q, dec);
}
if (valid) {
heap.addThreadQ(dist.reduce(), vec);
}
heap.checkThreadQ();
}
heap.reduce();
auto distanceOutBase = distanceOut[queryId][probeId].data();
auto indicesOutBase = indicesOut[queryId][probeId].data();
for (int i = threadIdx.x; i < k; i += blockDim.x) {
distanceOutBase[i] = smemK[i];
indicesOutBase[i] = smemV[i];
}
}
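For reference, here is the per-dimension stride the warp walks in this kernel, evaluated for each supported code width (a standalone sketch; the EncodeT widths are what the codecs in this diff imply, with sub-byte codes using uint8_t words):

```
#include <cstdio>

// Mirrors bytesPerVectorBlockDim / wordsPerVectorBlockDim from the kernel:
// how much data one dimension of a 32-vector block occupies.
void printBlockDimStride(const char* name, int encodeBits, int sizeofEncodeT) {
  int bytesPerVectorBlockDim = encodeBits * 32 / 8;
  int wordsPerVectorBlockDim = bytesPerVectorBlockDim / sizeofEncodeT;
  std::printf("%-8s %3d bytes = %2d EncodeT words per dim per block\n",
              name, bytesPerVectorBlockDim, wordsPerVectorBlockDim);
}

int main() {
  printBlockDimStride("float32", 32, 4); // 128 bytes = 32 floats
  printBlockDimStride("fp16", 16, 2);    //  64 bytes = 32 halves
  printBlockDimStride("sq8", 8, 1);      //  32 bytes = 32 bytes
  printBlockDimStride("sq6", 6, 1);      //  24 bytes, bit-packed across the warp
  printBlockDimStride("sq4", 4, 1);      //  16 bytes, bit-packed across the warp
  return 0;
}
```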
//
// We split up the scan function into multiple compilation units to cut down on
// compile time using these macros to define the function body
//
#define IVFINT_RUN(CODEC_TYPE, METRIC_TYPE, THREADS, NUM_WARP_Q, NUM_THREAD_Q) \
do { \
dim3 grid(nprobe, nq); \
if (useResidual) { \
ivfInterleavedScan<CODEC_TYPE, METRIC_TYPE, THREADS, NUM_WARP_Q, NUM_THREAD_Q, true> \
<<<grid, THREADS, codec.getSmemSize(dim), stream>>>( \
queries, \
residualBase, \
listIds, \
listData.data().get(), \
listLengths.data().get(), \
codec, \
metric, \
k, \
distanceTemp, \
indicesTemp); \
} else { \
ivfInterleavedScan<CODEC_TYPE, METRIC_TYPE, THREADS, NUM_WARP_Q, NUM_THREAD_Q, false> \
<<<grid, THREADS, codec.getSmemSize(dim), stream>>>( \
queries, \
residualBase, \
listIds, \
listData.data().get(), \
listLengths.data().get(), \
codec, \
metric, \
k, \
distanceTemp, \
indicesTemp); \
} \
\
runIVFInterleavedScan2(distanceTemp, \
indicesTemp, \
listIds, \
k, \
listIndices, \
indicesOptions, \
METRIC_TYPE::kDirection, \
outDistances, \
outIndices, \
stream); \
} while (0);
#define IVFINT_CODECS(METRIC_TYPE, THREADS, NUM_WARP_Q, NUM_THREAD_Q) \
do { \
if (!scalarQ) { \
using CodecT = CodecFloat; \
CodecT codec(dim * sizeof(float)); \
IVFINT_RUN(CodecT, METRIC_TYPE, THREADS, NUM_WARP_Q, NUM_THREAD_Q); \
} else { \
switch (scalarQ->qtype) { \
case ScalarQuantizer::QuantizerType::QT_8bit: \
{ \
using CodecT = Codec<ScalarQuantizer::QuantizerType::QT_8bit, 1>; \
CodecT \
codec(scalarQ->code_size, \
scalarQ->gpuTrained.data(), \
scalarQ->gpuTrained.data() + dim); \
IVFINT_RUN(CodecT, METRIC_TYPE, THREADS, NUM_WARP_Q, NUM_THREAD_Q); \
} \
break; \
case ScalarQuantizer::QuantizerType::QT_8bit_uniform: \
{ \
using CodecT = Codec<ScalarQuantizer::QuantizerType::QT_8bit_uniform, 1>; \
CodecT \
codec(scalarQ->code_size, scalarQ->trained[0], scalarQ->trained[1]); \
IVFINT_RUN(CodecT, METRIC_TYPE, THREADS, NUM_WARP_Q, NUM_THREAD_Q); \
} \
break; \
case ScalarQuantizer::QuantizerType::QT_fp16: \
{ \
using CodecT = Codec<ScalarQuantizer::QuantizerType::QT_fp16, 1>; \
CodecT \
codec(scalarQ->code_size); \
IVFINT_RUN(CodecT, METRIC_TYPE, THREADS, NUM_WARP_Q, NUM_THREAD_Q); \
} \
break; \
case ScalarQuantizer::QuantizerType::QT_8bit_direct: \
{ \
using CodecT = Codec<ScalarQuantizer::QuantizerType::QT_8bit_direct, 1>; \
Codec<ScalarQuantizer::QuantizerType::QT_8bit_direct, 1> \
codec(scalarQ->code_size); \
IVFINT_RUN(CodecT, METRIC_TYPE, THREADS, NUM_WARP_Q, NUM_THREAD_Q); \
} \
break; \
case ScalarQuantizer::QuantizerType::QT_6bit: \
{ \
using CodecT = Codec<ScalarQuantizer::QuantizerType::QT_6bit, 1>; \
Codec<ScalarQuantizer::QuantizerType::QT_6bit, 1> \
codec(scalarQ->code_size, \
scalarQ->gpuTrained.data(), \
scalarQ->gpuTrained.data() + dim); \
IVFINT_RUN(CodecT, METRIC_TYPE, THREADS, NUM_WARP_Q, NUM_THREAD_Q); \
} \
break; \
case ScalarQuantizer::QuantizerType::QT_4bit: \
{ \
using CodecT = Codec<ScalarQuantizer::QuantizerType::QT_4bit, 1>; \
Codec<ScalarQuantizer::QuantizerType::QT_4bit, 1> \
codec(scalarQ->code_size, \
scalarQ->gpuTrained.data(), \
scalarQ->gpuTrained.data() + dim); \
IVFINT_RUN(CodecT, METRIC_TYPE, THREADS, NUM_WARP_Q, NUM_THREAD_Q); \
} \
break; \
case ScalarQuantizer::QuantizerType::QT_4bit_uniform: \
{ \
using CodecT = Codec<ScalarQuantizer::QuantizerType::QT_4bit_uniform, 1>; \
Codec<ScalarQuantizer::QuantizerType::QT_4bit_uniform, 1> \
codec(scalarQ->code_size, scalarQ->trained[0], scalarQ->trained[1]); \
IVFINT_RUN(CodecT, METRIC_TYPE, THREADS, NUM_WARP_Q, NUM_THREAD_Q); \
} \
break; \
default: \
FAISS_ASSERT(false); \
} \
} \
} while (0)
#define IVFINT_METRICS(THREADS, NUM_WARP_Q, NUM_THREAD_Q) \
do { \
auto stream = res->getDefaultStreamCurrentDevice(); \
auto nq = queries.getSize(0); \
auto dim = queries.getSize(1); \
auto nprobe = listIds.getSize(1); \
\
DeviceTensor<float, 3, true> distanceTemp( \
res, makeTempAlloc(AllocType::Other, stream), \
{queries.getSize(0), listIds.getSize(1), k}); \
DeviceTensor<int, 3, true> indicesTemp( \
res, makeTempAlloc(AllocType::Other, stream), \
{queries.getSize(0), listIds.getSize(1), k}); \
\
if (metric == MetricType::METRIC_L2) { \
L2Distance metric; \
IVFINT_CODECS(L2Distance, THREADS, NUM_WARP_Q, NUM_THREAD_Q); \
} else if (metric == MetricType::METRIC_INNER_PRODUCT) { \
IPDistance metric; \
IVFINT_CODECS(IPDistance, THREADS, NUM_WARP_Q, NUM_THREAD_Q); \
} else { \
FAISS_ASSERT(false); \
} \
} while (0)
// Top-level IVF scan function for the interleaved by 32 layout
// with all implementations
void runIVFInterleavedScan(Tensor<float, 2, true>& queries,
Tensor<int, 2, true>& listIds,
thrust::device_vector<void*>& listData,
thrust::device_vector<void*>& listIndices,
IndicesOptions indicesOptions,
thrust::device_vector<int>& listLengths,
int k,
faiss::MetricType metric,
bool useResidual,
Tensor<float, 3, true>& residualBase,
GpuScalarQuantizer* scalarQ,
// output
Tensor<float, 2, true>& outDistances,
// output
Tensor<Index::idx_t, 2, true>& outIndices,
GpuResources* res);
// Second pass of IVF list scanning to perform final k-selection and look up the
// user indices
void runIVFInterleavedScan2(Tensor<float, 3, true>& distanceIn,
Tensor<int, 3, true>& indicesIn,
Tensor<int, 2, true>& listIds,
int k,
thrust::device_vector<void*>& listIndices,
IndicesOptions indicesOptions,
bool dir,
Tensor<float, 2, true>& distanceOut,
Tensor<Index::idx_t, 2, true>& indicesOut,
cudaStream_t stream);
} } // namespace


@ -11,6 +11,7 @@
#include <faiss/gpu/impl/BroadcastSum.cuh>
#include <faiss/gpu/impl/Distance.cuh>
#include <faiss/gpu/impl/FlatIndex.cuh>
#include <faiss/gpu/impl/InterleavedCodes.h>
#include <faiss/gpu/impl/IVFAppend.cuh>
#include <faiss/gpu/impl/L2Norm.cuh>
#include <faiss/gpu/impl/PQCodeDistances.cuh>
@ -40,7 +41,7 @@ IVFPQ::IVFPQ(GpuResources* resources,
int bitsPerSubQuantizer,
bool useFloat16LookupTables,
bool useMMCodeDistance,
bool alternativeLayout,
bool interleavedLayout,
float* pqCentroidData,
IndicesOptions indicesOptions,
MemorySpace space) :
@ -48,6 +49,7 @@ IVFPQ::IVFPQ(GpuResources* resources,
metric,
metricArg,
quantizer,
interleavedLayout,
indicesOptions,
space),
numSubQuantizers_(numSubQuantizers),
@ -56,7 +58,6 @@ IVFPQ::IVFPQ(GpuResources* resources,
dimPerSubQuantizer_(dim_ / numSubQuantizers),
useFloat16LookupTables_(useFloat16LookupTables),
useMMCodeDistance_(useMMCodeDistance),
alternativeLayout_(alternativeLayout),
precomputedCodes_(false) {
FAISS_ASSERT(pqCentroidData);
@ -119,6 +120,10 @@ IVFPQ::setPrecomputedCodes(bool enable) {
void
IVFPQ::appendVectors_(Tensor<float, 2, true>& vecs,
Tensor<Index::idx_t, 1, true>& indices,
Tensor<int, 1, true>& uniqueLists,
Tensor<int, 1, true>& vectorsByUniqueList,
Tensor<int, 1, true>& uniqueListVectorStart,
Tensor<int, 1, true>& uniqueListStartOffset,
Tensor<int, 1, true>& listIds,
Tensor<int, 1, true>& listOffset,
cudaStream_t stream) {
@ -209,7 +214,7 @@ IVFPQ::appendVectors_(Tensor<float, 2, true>& vecs,
listOffset,
encodings,
indices,
alternativeLayout_,
interleavedLayout_,
deviceListDataPointers_,
deviceListIndexPointers_,
indicesOptions_,
@ -218,7 +223,7 @@ IVFPQ::appendVectors_(Tensor<float, 2, true>& vecs,
size_t
IVFPQ::getGpuVectorsEncodingSize_(int numVecs) const {
if (alternativeLayout_) {
if (interleavedLayout_) {
return utils::roundUp(
(size_t) numVecs, (size_t) 32) * numSubQuantizers_;
} else {
@ -235,52 +240,24 @@ IVFPQ::getCpuVectorsEncodingSize_(int numVecs) const {
std::vector<uint8_t>
IVFPQ::translateCodesToGpu_(std::vector<uint8_t> codes,
size_t numVecs) const {
if (!alternativeLayout_) {
if (!interleavedLayout_) {
return codes;
}
auto totalSize = getGpuVectorsEncodingSize_(numVecs);
std::vector<uint8_t> out(totalSize);
for (int i = 0; i < numVecs; ++i) {
for (int j = 0; j < numSubQuantizers_; ++j) {
int block = (i / 32) * numSubQuantizers_ + j;
int withinBlock = i % 32;
size_t srcOffset = (size_t) i * numSubQuantizers_ + j;
size_t dstOffset = (size_t) block * 32 + withinBlock;
out[dstOffset] = codes[srcOffset];
}
}
return out;
auto up = unpackNonInterleaved(std::move(codes), numVecs, numSubQuantizers_, 8);
return packInterleaved(std::move(up), numVecs, numSubQuantizers_, 8);
}
// Convert the GPU layout to the CPU layout
std::vector<uint8_t>
IVFPQ::translateCodesFromGpu_(std::vector<uint8_t> codes,
size_t numVecs) const {
if (!alternativeLayout_) {
if (!interleavedLayout_) {
return codes;
}
auto totalSize = getCpuVectorsEncodingSize_(numVecs);
std::vector<uint8_t> out(totalSize);
for (int i = 0; i < numVecs; ++i) {
for (int j = 0; j < numSubQuantizers_; ++j) {
int block = (i / 32) * numSubQuantizers_ + j;
int withinBlock = i % 32;
size_t srcOffset = (size_t) block * 32 + withinBlock;
size_t dstOffset = (size_t) i * numSubQuantizers_ + j;
out[dstOffset] = codes[srcOffset];
}
}
return out;
auto up = unpackInterleaved(std::move(codes), numVecs, numSubQuantizers_, 8);
return packNonInterleaved(std::move(up), numVecs, numSubQuantizers_, 8);
}
void
@ -608,7 +585,7 @@ IVFPQ::runPQPrecomputedCodes_(
term3, // term 3
coarseIndices,
useFloat16LookupTables_,
alternativeLayout_,
interleavedLayout_,
numSubQuantizers_,
numSubQuantizerCodes_,
deviceListDataPointers_,
@ -640,7 +617,7 @@ IVFPQ::runPQNoPrecomputedCodesT_(
coarseIndices,
useFloat16LookupTables_,
useMMCodeDistance_,
alternativeLayout_,
interleavedLayout_,
numSubQuantizers_,
numSubQuantizerCodes_,
deviceListDataPointers_,


@ -27,7 +27,7 @@ class IVFPQ : public IVFBase {
int bitsPerSubQuantizer,
bool useFloat16LookupTables,
bool useMMCodeDistance,
bool alternativeLayout,
bool interleavedLayout,
float* pqCentroidData,
IndicesOptions indicesOptions,
MemorySpace space);
@ -68,6 +68,10 @@ class IVFPQ : public IVFBase {
/// Encode the vectors that we're adding and append to our IVF lists
void appendVectors_(Tensor<float, 2, true>& vecs,
Tensor<Index::idx_t, 1, true>& indices,
Tensor<int, 1, true>& uniqueLists,
Tensor<int, 1, true>& vectorsByUniqueList,
Tensor<int, 1, true>& uniqueListVectorStart,
Tensor<int, 1, true>& uniqueListStartOffset,
Tensor<int, 1, true>& listIds,
Tensor<int, 1, true>& listOffset,
cudaStream_t stream) override;
@ -135,16 +139,6 @@ class IVFPQ : public IVFBase {
/// purposes.
const bool useMMCodeDistance_;
/// The default memory layout is [vector][PQ component]:
/// (v0 d0) (v0 d1) ... (v0 dD-1) (v1 d0) (v1 d1) ...
///
/// An alternative memory layout (layoutBy32) is
/// [vector / 32][PQ component][vector % 32] with padding:
/// (v0 d0) (v1 d0) ... (v31 d0) (v0 d1) (v1 d1) ... (v31 dD-1) (v32 d0) (v33
/// d0) ...
/// so the list length is always a multiple of numSubQuantizers * 32
const bool alternativeLayout_;
/// On the GPU, we prefer different PQ centroid data layouts for
/// different purposes.
///


@ -0,0 +1,410 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <faiss/gpu/impl/InterleavedCodes.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/gpu/utils/StaticUtils.h>
namespace faiss { namespace gpu {
std::vector<uint8_t>
unpackNonInterleaved(std::vector<uint8_t> data,
int numVecs,
int dims,
int bitsPerCode) {
int srcVecSize = utils::divUp(dims * bitsPerCode, 8);
FAISS_ASSERT(data.size() == numVecs * srcVecSize);
if (bitsPerCode == 8 ||
bitsPerCode == 16 ||
bitsPerCode == 32) {
// nothing to do
return data;
}
// bit codes padded to whole bytes
std::vector<uint8_t> out(numVecs * dims * utils::divUp(bitsPerCode, 8));
if (bitsPerCode == 6) {
#pragma omp parallel for
for (int i = 0; i < numVecs; ++i) {
for (int j = 0; j < dims; ++j) {
int lo = i * srcVecSize + (j * 6) / 8;
int hi = lo + 1;
FAISS_ASSERT(lo < data.size());
FAISS_ASSERT(hi <= data.size());
auto vLower = data[lo];
auto vUpper = hi < data.size() ? data[hi] : 0;
uint8_t v = 0;
switch (j % 4) {
case 0:
// 6 lsbs of lower
v = vLower & 0x3f;
break;
case 1:
// 2 msbs of lower as v lsbs
// 4 lsbs of upper as v msbs
v = (vLower >> 6) | ((vUpper & 0xf) << 2);
break;
case 2:
// 4 msbs of lower as v lsbs
// 2 lsbs of upper as v msbs
v = (vLower >> 4) | ((vUpper & 0x3) << 4);
break;
case 3:
// 6 msbs of lower
v = (vLower >> 2);
break;
}
out[i * dims + j] = v;
}
}
} else if (bitsPerCode == 4) {
#pragma omp parallel for
for (int i = 0; i < numVecs; ++i) {
for (int j = 0; j < dims; ++j) {
int srcIdx = i * srcVecSize + (j / 2);
FAISS_ASSERT(srcIdx < data.size());
uint8_t v = data[srcIdx];
v = (j % 2 == 0) ? v & 0xf : v >> 4;
out[i * dims + j] = v;
}
}
} else {
// unhandled
FAISS_ASSERT(false);
}
return out;
}
template <typename T>
void
unpackInterleavedWord(const T* in,
T* out,
int numVecs,
int dims,
int bitsPerCode) {
int wordsPerDimBlock = 32 * bitsPerCode / (8 * sizeof(T));
int wordsPerBlock = wordsPerDimBlock * dims;
int numBlocks = utils::divUp(numVecs, 32);
#pragma omp parallel for
for (int i = 0; i < numVecs; ++i) {
int block = i / 32;
FAISS_ASSERT(block < numBlocks);
int lane = i % 32;
for (int j = 0; j < dims; ++j) {
int srcOffset = block * wordsPerBlock + j * wordsPerDimBlock + lane;
out[i * dims + j] = in[srcOffset];
}
}
}
std::vector<uint8_t>
unpackInterleaved(std::vector<uint8_t> data,
int numVecs,
int dims,
int bitsPerCode) {
int bytesPerDimBlock = 32 * bitsPerCode / 8;
int bytesPerBlock = bytesPerDimBlock * dims;
int numBlocks = utils::divUp(numVecs, 32);
size_t totalSize = (size_t) bytesPerBlock * numBlocks;
FAISS_ASSERT(data.size() == totalSize);
// bit codes padded to whole bytes
std::vector<uint8_t> out(numVecs * dims * utils::divUp(bitsPerCode, 8));
if (bitsPerCode == 8) {
unpackInterleavedWord<uint8_t>(data.data(), out.data(),
numVecs, dims, bitsPerCode);
} else if (bitsPerCode == 16) {
unpackInterleavedWord<uint16_t>((uint16_t*) data.data(),
(uint16_t*) out.data(),
numVecs, dims, bitsPerCode);
} else if (bitsPerCode == 32) {
unpackInterleavedWord<uint32_t>((uint32_t*) data.data(),
(uint32_t*) out.data(),
numVecs, dims, bitsPerCode);
} else if (bitsPerCode == 4) {
#pragma omp parallel for
for (int i = 0; i < numVecs; ++i) {
int block = i / 32;
int lane = i % 32;
int word = lane / 2;
int subWord = lane % 2;
for (int j = 0; j < dims; ++j) {
auto v =
data[block * bytesPerBlock + j * bytesPerDimBlock + word];
v = (subWord == 0) ? v & 0xf : v >> 4;
out[i * dims + j] = v;
}
}
} else if (bitsPerCode == 6) {
#pragma omp parallel for
for (int i = 0; i < numVecs; ++i) {
int block = i / 32;
int blockVector = i % 32;
for (int j = 0; j < dims; ++j) {
uint8_t* dimBlock =
&data[block * bytesPerBlock + j * bytesPerDimBlock];
int lo = (blockVector * 6) / 8;
int hi = lo + 1;
FAISS_ASSERT(lo < bytesPerDimBlock);
FAISS_ASSERT(hi <= bytesPerDimBlock);
auto vLower = dimBlock[lo];
auto vUpper = hi < bytesPerDimBlock ? dimBlock[hi] : 0;
uint8_t v = 0;
switch (blockVector % 4) {
case 0:
// 6 lsbs of lower
v = vLower & 0x3f;
break;
case 1:
// 2 msbs of lower as v lsbs
// 4 lsbs of upper as v msbs
v = (vLower >> 6) | ((vUpper & 0xf) << 2);
break;
case 2:
// 4 msbs of lower as v lsbs
// 2 lsbs of upper as v msbs
v = (vLower >> 4) | ((vUpper & 0x3) << 4);
break;
case 3:
// 6 msbs of lower
v = (vLower >> 2);
break;
}
out[i * dims + j] = v;
}
}
} else {
// unimplemented
FAISS_ASSERT(false);
}
return out;
}
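The 6-bit branch above is the trickiest case; here is a worked extraction for one 24-byte dimension block (a standalone mirror of the switch, with a hand-packed byte pattern):

```
#include <cassert>
#include <cstdint>

// Extract the 6-bit code of lane `lane` (0..31) from one 24-byte dimension
// block, mirroring the (blockVector % 4) switch in unpackInterleaved.
uint8_t extract6(const uint8_t* dimBlock, int lane) {
  int lo = (lane * 6) / 8;
  int hi = lo + 1;
  uint8_t vLower = dimBlock[lo];
  uint8_t vUpper = hi < 24 ? dimBlock[hi] : 0;
  switch (lane % 4) {
    case 0:  return vLower & 0x3f;
    case 1:  return (vLower >> 6) | ((vUpper & 0xf) << 2);
    case 2:  return (vLower >> 4) | ((vUpper & 0x3) << 4);
    default: return vLower >> 2;
  }
}

int main() {
  // Lanes 0..3 hold codes 0x2a, 0x15, 0x3f, 0x01; packed 4-codes-into-3-bytes
  // they become 0x6a, 0xf5, 0x07 (see the lsb..msb table in the pack functions)
  uint8_t dimBlock[24] = {0x6a, 0xf5, 0x07};
  assert(extract6(dimBlock, 0) == 0x2a);
  assert(extract6(dimBlock, 1) == 0x15);
  assert(extract6(dimBlock, 2) == 0x3f);
  assert(extract6(dimBlock, 3) == 0x01);
  return 0;
}
```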
std::vector<uint8_t>
packNonInterleaved(std::vector<uint8_t> data,
int numVecs,
int dims,
int bitsPerCode) {
// bit codes padded to whole bytes
FAISS_ASSERT(data.size() == numVecs * dims * utils::divUp(bitsPerCode, 8));
if (bitsPerCode == 8 ||
bitsPerCode == 16 ||
bitsPerCode == 32) {
// nothing to do, whole words are already where they need to be
return data;
}
// bits packed into a whole number of bytes
int bytesPerVec = utils::divUp(dims * bitsPerCode, 8);
std::vector<uint8_t> out(numVecs * bytesPerVec);
if (bitsPerCode == 4) {
#pragma omp parallel for
for (int i = 0; i < numVecs; ++i) {
for (int j = 0; j < bytesPerVec; ++j) {
int dimLo = j * 2;
int dimHi = dimLo + 1;
FAISS_ASSERT(dimLo < dims);
FAISS_ASSERT(dimHi <= dims);
uint8_t lo = data[i * dims + dimLo];
uint8_t hi = dimHi < dims ? data[i * dims + dimHi] : 0;
out[i * bytesPerVec + j] = (hi << 4) | (lo & 0xf);
}
}
} else if (bitsPerCode == 6) {
#pragma omp parallel for
for (int i = 0; i < numVecs; ++i) {
for (int j = 0; j < bytesPerVec; ++j) {
int dimLo = (j * 8) / 6;
int dimHi = dimLo + 1;
FAISS_ASSERT(dimLo < dims);
FAISS_ASSERT(dimHi <= dims);
uint8_t lo = data[i * dims + dimLo];
uint8_t hi = dimHi < dims ? data[i * dims + dimHi] : 0;
uint8_t v = 0;
// lsb ... msb
// 0: 0 0 0 0 0 0 1 1
// 1: 1 1 1 1 2 2 2 2
// 2: 2 2 3 3 3 3 3 3
switch (j % 3) {
case 0:
// all 6 bits of lower as vOut lsbs
// 2 lsbs of upper as vOut msbs
v = (lo & 0x3f) | (hi << 6);
break;
case 1:
// 4 msbs of lower as vOut lsbs
// 4 lsbs of upper as vOut msbs
v = (lo >> 2) | (hi << 4);
break;
case 2:
// 2 msbs of lower as vOut lsbs
// 6 lsbs of upper as vOut msbs
v = (lo >> 4) | (hi << 2);
break;
}
out[i * bytesPerVec + j] = v;
}
}
} else {
// unhandled
FAISS_ASSERT(false);
}
return out;
}
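And the inverse direction, a worked instance of the lsb..msb table above: packing four 6-bit codes into three bytes (an illustrative helper reproducing the j % 3 switch, not a faiss API):

```
#include <cassert>
#include <cstdint>

// Pack four 6-bit codes into three bytes, following the lsb..msb table:
//   byte 0: all 6 bits of c0, low 2 bits of c1
//   byte 1: high 4 bits of c1, low 4 bits of c2
//   byte 2: high 2 bits of c2, all 6 bits of c3
void pack6x4(const uint8_t codes[4], uint8_t out[3]) {
  out[0] = (codes[0] & 0x3f) | (codes[1] << 6);
  out[1] = (codes[1] >> 2) | (codes[2] << 4);
  out[2] = (codes[2] >> 4) | (codes[3] << 2);
}

int main() {
  uint8_t codes[4] = {0x2a, 0x15, 0x3f, 0x01};
  uint8_t out[3] = {};
  pack6x4(codes, out);
  assert(out[0] == 0x6a && out[1] == 0xf5 && out[2] == 0x07);
  return 0;
}
```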
template <typename T>
void
packInterleavedWord(const T* in,
T* out,
int numVecs,
int dims,
int bitsPerCode) {
int wordsPerDimBlock = 32 * bitsPerCode / (8 * sizeof(T));
int wordsPerBlock = wordsPerDimBlock * dims;
int numBlocks = utils::divUp(numVecs, 32);
// Slots not filled by the vectors present are guaranteed to be zero,
// since the output vector in packInterleaved is zero-initialized
#pragma omp parallel for
for (int i = 0; i < numVecs; ++i) {
int block = i / 32;
FAISS_ASSERT(block < numBlocks);
int lane = i % 32;
for (int j = 0; j < dims; ++j) {
int dstOffset = block * wordsPerBlock + j * wordsPerDimBlock + lane;
out[dstOffset] = in[i * dims + j];
}
}
}
std::vector<uint8_t>
packInterleaved(std::vector<uint8_t> data,
int numVecs,
int dims,
int bitsPerCode) {
int bytesPerDimBlock = 32 * bitsPerCode / 8;
int bytesPerBlock = bytesPerDimBlock * dims;
int numBlocks = utils::divUp(numVecs, 32);
size_t totalSize = (size_t) bytesPerBlock * numBlocks;
// bit codes padded to whole bytes
FAISS_ASSERT(data.size() == numVecs * dims * utils::divUp(bitsPerCode, 8));
// packs based on blocks
std::vector<uint8_t> out(totalSize, 0);
if (bitsPerCode == 8) {
packInterleavedWord<uint8_t>(data.data(), out.data(),
numVecs, dims, bitsPerCode);
} else if (bitsPerCode == 16) {
packInterleavedWord<uint16_t>((uint16_t*) data.data(),
(uint16_t*) out.data(),
numVecs, dims, bitsPerCode);
} else if (bitsPerCode == 32) {
packInterleavedWord<uint32_t>((uint32_t*) data.data(),
(uint32_t*) out.data(),
numVecs, dims, bitsPerCode);
} else if (bitsPerCode == 4) {
#pragma omp parallel for
for (int i = 0; i < numBlocks; ++i) {
for (int j = 0; j < dims; ++j) {
for (int k = 0; k < bytesPerDimBlock; ++k) {
int loVec = i * 32 + k * 2;
int hiVec = loVec + 1;
uint8_t lo = loVec < numVecs ? data[loVec * dims + j] : 0;
uint8_t hi = hiVec < numVecs ? data[hiVec * dims + j] : 0;
out[i * bytesPerBlock + j * bytesPerDimBlock + k] =
(hi << 4) | (lo & 0xf);
}
}
}
} else if (bitsPerCode == 6) {
#pragma omp parallel for
for (int i = 0; i < numBlocks; ++i) {
for (int j = 0; j < dims; ++j) {
for (int k = 0; k < bytesPerDimBlock; ++k) {
// Which input vectors we are pulling from
int loVec = i * 32 + (k * 8) / 6;
int hiVec = loVec + 1;
uint8_t lo = loVec < numVecs ? data[loVec * dims + j] : 0;
uint8_t hi = hiVec < numVecs ? data[hiVec * dims + j] : 0;
uint8_t v = 0;
// lsb ... msb
// 0: 0 0 0 0 0 0 1 1
// 1: 1 1 1 1 2 2 2 2
// 2: 2 2 3 3 3 3 3 3
switch (k % 3) {
case 0:
// all 6 bits of lower as vOut lsbs
// 2 lsbs of upper as vOut msbs
v = (lo & 0x3f) | (hi << 6);
break;
case 1:
// 4 msbs of lower as vOut lsbs
// 4 lsbs of upper as vOut msbs
v = (lo >> 2) | (hi << 4);
break;
case 2:
// 2 msbs of lower as vOut lsbs
// 6 lsbs of upper as vOut msbs
v = (lo >> 4) | (hi << 2);
break;
}
out[i * bytesPerBlock + j * bytesPerDimBlock + k] = v;
}
}
}
} else {
// unimplemented
FAISS_ASSERT(false);
}
return out;
}
} } // namespace


@ -0,0 +1,51 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <stdint.h>
#include <vector>
// Utilities for bit packing and unpacking CPU non-interleaved and GPU
// interleaved by 32 encodings
namespace faiss { namespace gpu {
// Unpacks arbitrary bitwidth codes to a whole number of bytes per code
// The layout of the input is (v0 d0)(v0 d1) ... (v0 dD)(v1 d0) ...
// (bit packed)
// The layout of the output is the same (byte packed to roundUp(bitsPerCode, 8)
// / 8 bytes)
std::vector<uint8_t> unpackNonInterleaved(std::vector<uint8_t> data,
int numVecs,
int dims,
int bitsPerCode);
// Unpacks arbitrary bitwidth codes to a whole number of bytes per scalar code
// The layout of the input is (v0 d0)(v1 d0) ... (v31 d0)(v0 d1) ...
// (bit packed)
// The layout of the output is (v0 d0)(v0 d1) ... (v0 dD)(v1 d0) ...
// (byte packed)
std::vector<uint8_t> unpackInterleaved(std::vector<uint8_t> data,
int numVecs,
int dims,
int bitsPerCode);
// Packs data in the byte packed non-interleaved form to bit packed
// non-interleaved form
std::vector<uint8_t> packNonInterleaved(std::vector<uint8_t> data,
int numVecs,
int dims,
int bitsPerCode);
// Packs data in the byte packed non-interleaved form to bit packed
// interleaved form
std::vector<uint8_t> packInterleaved(std::vector<uint8_t> data,
int numVecs,
int dims,
int bitsPerCode);
} } // namespace
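A small usage sketch of these helpers on the whole-byte path (SQ8 here), showing the interleaving explicitly; the values are made up for illustration:

```
#include <faiss/gpu/impl/InterleavedCodes.h>

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  using namespace faiss::gpu;

  // Three SQ8 vectors of dimension 2 in the CPU layout (v0 d0)(v0 d1)(v1 d0)...
  std::vector<uint8_t> cpu = {10, 11, 20, 21, 30, 31};

  // One interleaved block of 32 vectors x 2 dims = 64 bytes; unused lanes are 0
  auto gpu = packInterleaved(cpu, /*numVecs=*/3, /*dims=*/2, /*bitsPerCode=*/8);
  assert(gpu.size() == 64);
  assert(gpu[0] == 10 && gpu[1] == 20 && gpu[2] == 30);    // d0 of v0, v1, v2
  assert(gpu[32] == 11 && gpu[33] == 21 && gpu[34] == 31); // d1 of v0, v1, v2

  // and back to the CPU layout
  assert(unpackInterleaved(gpu, 3, 2, 8) == cpu);
  return 0;
}
```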


@ -0,0 +1,15 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <faiss/gpu/impl/scan/IVFInterleavedImpl.cuh>
namespace faiss { namespace gpu {
IVF_INTERLEAVED_IMPL(128, 1, 1)
} } // namespace


@ -0,0 +1,15 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <faiss/gpu/impl/scan/IVFInterleavedImpl.cuh>
namespace faiss { namespace gpu {
IVF_INTERLEAVED_IMPL(128, 1024, 8)
} } // namespace


@ -0,0 +1,15 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <faiss/gpu/impl/scan/IVFInterleavedImpl.cuh>
namespace faiss { namespace gpu {
IVF_INTERLEAVED_IMPL(128, 128, 3)
} } // namespace


@ -0,0 +1,17 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <faiss/gpu/impl/scan/IVFInterleavedImpl.cuh>
namespace faiss { namespace gpu {
#if GPU_MAX_SELECTION_K >= 2048
IVF_INTERLEAVED_IMPL(64, 2048, 8)
#endif
} } // namespace


@ -0,0 +1,15 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <faiss/gpu/impl/scan/IVFInterleavedImpl.cuh>
namespace faiss { namespace gpu {
IVF_INTERLEAVED_IMPL(128, 256, 4)
} } // namespace


@ -0,0 +1,15 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <faiss/gpu/impl/scan/IVFInterleavedImpl.cuh>
namespace faiss { namespace gpu {
IVF_INTERLEAVED_IMPL(128, 32, 2)
} } // namespace


@ -0,0 +1,15 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <faiss/gpu/impl/scan/IVFInterleavedImpl.cuh>
namespace faiss { namespace gpu {
IVF_INTERLEAVED_IMPL(128, 512, 8)
} } // namespace


@ -0,0 +1,15 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <faiss/gpu/impl/scan/IVFInterleavedImpl.cuh>
namespace faiss { namespace gpu {
IVF_INTERLEAVED_IMPL(128, 64, 3)
} } // namespace


@ -0,0 +1,86 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <faiss/gpu/impl/IVFInterleaved.cuh>
#include <faiss/gpu/utils/DeviceDefs.cuh>
#define IVF_INTERLEAVED_IMPL(THREADS, WARP_Q, THREAD_Q) \
\
void ivfInterleavedScanImpl_ ## WARP_Q ## _( \
Tensor<float, 2, true>& queries, \
Tensor<int, 2, true>& listIds, \
thrust::device_vector<void*>& listData, \
thrust::device_vector<void*>& listIndices, \
IndicesOptions indicesOptions, \
thrust::device_vector<int>& listLengths, \
int k, \
faiss::MetricType metric, \
bool useResidual, \
Tensor<float, 3, true>& residualBase, \
GpuScalarQuantizer* scalarQ, \
Tensor<float, 2, true>& outDistances, \
Tensor<Index::idx_t, 2, true>& outIndices, \
GpuResources* res) { \
FAISS_ASSERT(k <= WARP_Q); \
\
IVFINT_METRICS(THREADS, WARP_Q, THREAD_Q); \
\
CUDA_TEST_ERROR(); \
}
#define IVF_INTERLEAVED_DECL(WARP_Q) \
\
void ivfInterleavedScanImpl_ ## WARP_Q ## _( \
Tensor<float, 2, true>& queries, \
Tensor<int, 2, true>& listIds, \
thrust::device_vector<void*>& listData, \
thrust::device_vector<void*>& listIndices, \
IndicesOptions indicesOptions, \
thrust::device_vector<int>& listLengths, \
int k, \
faiss::MetricType metric, \
bool useResidual, \
Tensor<float, 3, true>& residualBase, \
GpuScalarQuantizer* scalarQ, \
Tensor<float, 2, true>& outDistances, \
Tensor<Index::idx_t, 2, true>& outIndices, \
GpuResources* res)
#define IVF_INTERLEAVED_CALL(WARP_Q) \
ivfInterleavedScanImpl_ ## WARP_Q ## _( \
queries, \
listIds, \
listData, \
listIndices, \
indicesOptions, \
listLengths, \
k, \
metric, \
useResidual, \
residualBase, \
scalarQ, \
outDistances, \
outIndices, \
res)
namespace faiss { namespace gpu {
IVF_INTERLEAVED_DECL(1);
IVF_INTERLEAVED_DECL(32);
IVF_INTERLEAVED_DECL(64);
IVF_INTERLEAVED_DECL(128);
IVF_INTERLEAVED_DECL(256);
IVF_INTERLEAVED_DECL(512);
IVF_INTERLEAVED_DECL(1024);
#if GPU_MAX_SELECTION_K >= 2048
IVF_INTERLEAVED_DECL(2048);
#endif
} } // namespace


@ -19,6 +19,7 @@ macro(faiss_gpu_test file)
gtest_discover_tests(${test_name})
endmacro()
faiss_gpu_test(TestCodePacking.cpp)
faiss_gpu_test(TestGpuIndexFlat.cpp)
faiss_gpu_test(TestGpuIndexIVFFlat.cpp)
faiss_gpu_test(TestGpuIndexBinaryFlat.cpp)


@ -0,0 +1,241 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <faiss/gpu/impl/InterleavedCodes.h>
#include <faiss/gpu/utils/StaticUtils.h>
#include <faiss/gpu/test/TestUtils.h>
#include <cmath>
#include <gtest/gtest.h>
#include <random>
#include <sstream>
#include <vector>
TEST(TestCodePacking, NonInterleavedCodes_UnpackPack) {
using namespace faiss::gpu;
// We are fine using non-fixed seeds here; the results should be fully
// deterministic
auto seed = std::random_device()();
std::mt19937 gen(seed);
std::uniform_int_distribution<uint8_t> dist;
std::cout << "seed " << seed << "\n";
for (auto bitsPerCode : {4, 6, 8, 16, 32}) {
for (auto dims : {1, 7, 8, 31, 32}) {
for (auto numVecs : {1, 3, 4, 5, 6, 8, 31, 32, 33, 65}) {
std::cout << bitsPerCode << " " << dims << " " << numVecs << "\n";
int srcVecSize = utils::divUp(dims * bitsPerCode, 8);
std::vector<uint8_t> data(numVecs * srcVecSize);
for (auto& v : data) {
v = dist(gen);
}
// currently unimplemented
EXPECT_FALSE(bitsPerCode > 8 && bitsPerCode % 8 != 0);
// Due to bit packing, mask out the padding bits in each vector's last byte;
// they correspond to dimensions that are not present and must be zero
int vectorSizeBits = dims * bitsPerCode;
int vectorSizeBytes = utils::divUp(vectorSizeBits, 8);
int remainder = vectorSizeBits % 8;
if (remainder > 0) {
uint8_t mask = 0xff >> (8 - remainder);
for (int i = 0; i < numVecs; ++i) {
int lastVecByte = (i + 1) * vectorSizeBytes - 1;
data[lastVecByte] &= mask;
}
}
auto up = unpackNonInterleaved(data, numVecs, dims, bitsPerCode);
auto p = packNonInterleaved(up, numVecs, dims, bitsPerCode);
EXPECT_EQ(data, p);
}
}
}
}
TEST(TestCodePacking, NonInterleavedCodes_PackUnpack) {
using namespace faiss::gpu;
// We are fine using non-fixed seeds here; the results should be fully
// deterministic
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<uint8_t> dist;
for (auto bitsPerCode : {4, 6, 8, 16, 32}) {
for (auto dims : {1, 7, 8, 31, 32}) {
for (auto numVecs : {1, 3, 4, 5, 6, 8, 31, 32, 33, 65}) {
std::cout << bitsPerCode << " " << dims << " " << numVecs << "\n";
std::vector<uint8_t> data(numVecs * dims * utils::divUp(bitsPerCode, 8));
// currently unimplemented
EXPECT_FALSE(bitsPerCode > 8 && bitsPerCode % 8 != 0);
// Mask out high bits we shouldn't have based on code size
uint8_t mask = bitsPerCode < 8 ? (0xff >> (8 - bitsPerCode)) : 0xff;
for (auto& v : data) {
v = dist(gen) & mask;
}
auto p = packNonInterleaved(data, numVecs, dims, bitsPerCode);
auto up = unpackNonInterleaved(p, numVecs, dims, bitsPerCode);
EXPECT_EQ(data, up);
}
}
}
}
TEST(TestCodePacking, InterleavedCodes_UnpackPack) {
using namespace faiss::gpu;
// We are fine using non-fixed seeds here; the results should be fully
// deterministic
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<uint8_t> dist;
for (auto bitsPerCode : {4, 6, 8, 16, 32}) {
for (auto dims : {1, 7, 8, 31, 32}) {
for (auto numVecs : {1, 3, 4, 5, 6, 8, 31, 32, 33, 65}) {
std::cout << bitsPerCode << " " << dims << " " << numVecs << "\n";
int blocks = utils::divUp(numVecs, 32);
int bytesPerDimBlock = 32 * bitsPerCode / 8;
int bytesPerBlock = bytesPerDimBlock * dims;
int size = blocks * bytesPerBlock;
std::vector<uint8_t> data(size);
if (bitsPerCode == 8 || bitsPerCode == 16 || bitsPerCode == 32) {
int bytesPerCode = bitsPerCode / 8;
for (int i = 0; i < blocks; ++i) {
for (int j = 0; j < dims; ++j) {
for (int k = 0; k < 32; ++k) {
for (int l = 0; l < bytesPerCode; ++l) {
int vec = i * 32 + k;
if (vec < numVecs) {
data[i * bytesPerBlock +
j * bytesPerDimBlock +
k * bytesPerCode + l] = dist(gen);
}
}
}
}
}
} else if (bitsPerCode < 8) {
for (int i = 0; i < blocks; ++i) {
for (int j = 0; j < dims; ++j) {
for (int k = 0; k < bytesPerDimBlock; ++k) {
int loVec = i * 32 + (k * 8) / bitsPerCode;
int hiVec = loVec + 1;
uint8_t lo = loVec < numVecs ?
dist(gen) & (0xff >> (8 - bitsPerCode)) : 0;
uint8_t hi = hiVec < numVecs ?
dist(gen) & (0xff >> (8 - bitsPerCode)) : 0;
uint8_t v = 0;
if (bitsPerCode == 4) {
v = lo | (hi << 4);
} else if (bitsPerCode == 6) {
switch (k % 3) {
case 0:
// all 6 bits of lower as vOut lsbs
// 2 lsbs of upper as vOut msbs
v = (lo & 0x3f) | (hi << 6);
break;
case 1:
// 4 msbs of lower as vOut lsbs
// 4 lsbs of upper as vOut msbs
v = (lo >> 2) | (hi << 4);
break;
case 2:
// 2 msbs of lower as vOut lsbs
// 6 lsbs of upper as vOut msbs
v = (lo >> 4) | (hi << 2);
break;
}
} else {
// unimplemented
EXPECT_TRUE(false);
}
data[i * bytesPerBlock + j * bytesPerDimBlock + k] = v;
}
}
}
} else {
// unimplemented
EXPECT_TRUE(false);
}
auto up = unpackInterleaved(data, numVecs, dims, bitsPerCode);
auto p = packInterleaved(up, numVecs, dims, bitsPerCode);
EXPECT_EQ(data, p);
}
}
}
}
TEST(TestCodePacking, InterleavedCodes_PackUnpack) {
using namespace faiss::gpu;
// We are fine using non-fixed seeds here; the results should be fully
// deterministic
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<uint8_t> dist;
for (auto bitsPerCode : {4, 6, 8, 16, 32}) {
for (auto dims : {1, 7, 8, 31, 32}) {
for (auto numVecs : {1, 3, 4, 5, 6, 8, 31, 32, 33, 65}) {
std::cout << bitsPerCode << " " << dims << " " << numVecs << "\n";
std::vector<uint8_t> data(numVecs * dims * utils::divUp(bitsPerCode, 8));
if (bitsPerCode == 8 || bitsPerCode == 16 || bitsPerCode == 32) {
for (auto& v : data) {
v = dist(gen);
}
} else if (bitsPerCode < 8) {
uint8_t mask = 0xff >> (8 - bitsPerCode);
for (auto& v : data) {
v = dist(gen) & mask;
}
} else {
// unimplemented
EXPECT_TRUE(false);
}
auto p = packInterleaved(data, numVecs, dims, bitsPerCode);
auto up = unpackInterleaved(p, numVecs, dims, bitsPerCode);
EXPECT_EQ(data, up);
}
}
}
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}


@ -470,9 +470,10 @@ TEST(TestGpuIndexIVFFlat, AddNaN) {
std::vector<float> nans(numNans * opt.dim,
std::numeric_limits<float>::quiet_NaN());
// Make one vector valid, which should actually add
// Make one vector valid (not the first vector, in order to test offset
// issues), which should actually add
for (int i = 0; i < opt.dim; ++i) {
nans[i] = 0.0f;
nans[opt.dim + i] = i;
}
std::vector<float> trainVecs = faiss::gpu::randVecs(opt.numTrain, opt.dim);


@ -35,8 +35,8 @@ struct Options {
k = std::min(faiss::gpu::randVal(10, 30), numAdd / 40);
indicesOpt = faiss::gpu::randSelect({
faiss::gpu::INDICES_CPU,
faiss::gpu::INDICES_32_BIT,
faiss::gpu::INDICES_64_BIT});
faiss::gpu::INDICES_32_BIT,
faiss::gpu::INDICES_64_BIT});
device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1);
}
@ -66,118 +66,159 @@ struct Options {
faiss::gpu::IndicesOptions indicesOpt;
};
TEST(TestGpuIndexIVFScalarQuantizer, CopyTo) {
void runCopyToTest(faiss::ScalarQuantizer::QuantizerType qtype) {
using namespace faiss;
using namespace faiss::gpu;
for (auto qtype : {ScalarQuantizer::QuantizerType::QT_8bit,
ScalarQuantizer::QuantizerType::QT_4bit}) {
Options opt;
std::vector<float> trainVecs = randVecs(opt.numTrain, opt.dim);
std::vector<float> addVecs = randVecs(opt.numAdd, opt.dim);
Options opt;
std::vector<float> trainVecs = randVecs(opt.numTrain, opt.dim);
std::vector<float> addVecs = randVecs(opt.numAdd, opt.dim);
StandardGpuResources res;
res.noTempMemory();
StandardGpuResources res;
res.noTempMemory();
auto config = GpuIndexIVFScalarQuantizerConfig();
config.device = opt.device;
auto config = GpuIndexIVFScalarQuantizerConfig();
config.device = opt.device;
GpuIndexIVFScalarQuantizer gpuIndex(&res,
opt.dim,
opt.numCentroids,
qtype,
METRIC_L2,
true,
config);
gpuIndex.train(opt.numTrain, trainVecs.data());
gpuIndex.add(opt.numAdd, addVecs.data());
gpuIndex.setNumProbes(opt.nprobe);
GpuIndexIVFScalarQuantizer gpuIndex(&res,
opt.dim,
opt.numCentroids,
qtype,
METRIC_L2,
true,
config);
gpuIndex.train(opt.numTrain, trainVecs.data());
gpuIndex.add(opt.numAdd, addVecs.data());
gpuIndex.setNumProbes(opt.nprobe);
// use garbage values to see if we overwrite them
IndexFlatL2 cpuQuantizer(1);
IndexIVFScalarQuantizer cpuIndex(&cpuQuantizer, 1, 1,
ScalarQuantizer::QuantizerType::QT_6bit,
METRIC_L2);
cpuIndex.nprobe = 1;
// use garbage values to see if we overwrite them
IndexFlatL2 cpuQuantizer(1);
IndexIVFScalarQuantizer cpuIndex(&cpuQuantizer, 1, 1,
ScalarQuantizer::QuantizerType::QT_6bit,
METRIC_L2);
cpuIndex.nprobe = 1;
gpuIndex.copyTo(&cpuIndex);
gpuIndex.copyTo(&cpuIndex);
EXPECT_EQ(cpuIndex.ntotal, gpuIndex.ntotal);
EXPECT_EQ(gpuIndex.ntotal, opt.numAdd);
EXPECT_EQ(cpuIndex.ntotal, gpuIndex.ntotal);
EXPECT_EQ(gpuIndex.ntotal, opt.numAdd);
EXPECT_EQ(cpuIndex.d, gpuIndex.d);
EXPECT_EQ(cpuIndex.quantizer->d, gpuIndex.quantizer->d);
EXPECT_EQ(cpuIndex.d, opt.dim);
EXPECT_EQ(cpuIndex.nlist, gpuIndex.getNumLists());
EXPECT_EQ(cpuIndex.nprobe, gpuIndex.getNumProbes());
EXPECT_EQ(cpuIndex.d, gpuIndex.d);
EXPECT_EQ(cpuIndex.quantizer->d, gpuIndex.quantizer->d);
EXPECT_EQ(cpuIndex.d, opt.dim);
EXPECT_EQ(cpuIndex.nlist, gpuIndex.getNumLists());
EXPECT_EQ(cpuIndex.nprobe, gpuIndex.getNumProbes());
testIVFEquality(cpuIndex, gpuIndex);
testIVFEquality(cpuIndex, gpuIndex);
// Query both objects; results should be equivalent
compareIndices(cpuIndex, gpuIndex,
opt.numQuery, opt.dim, opt.k, opt.toString(),
kF32MaxRelErr,
0.1f,
0.015f);
}
// Query both objects; results should be equivalent
compareIndices(cpuIndex, gpuIndex,
opt.numQuery, opt.dim, opt.k, opt.toString(),
kF32MaxRelErr,
0.1f,
0.015f);
}
TEST(TestGpuIndexIVFScalarQuantizer, CopyTo_fp16) {
runCopyToTest(faiss::ScalarQuantizer::QuantizerType::QT_fp16);
}
TEST(TestGpuIndexIVFScalarQuantizer, CopyFrom) {
TEST(TestGpuIndexIVFScalarQuantizer, CopyTo_8bit) {
runCopyToTest(faiss::ScalarQuantizer::QuantizerType::QT_8bit);
}
TEST(TestGpuIndexIVFScalarQuantizer, CopyTo_8bit_uniform) {
runCopyToTest(faiss::ScalarQuantizer::QuantizerType::QT_8bit_uniform);
}
TEST(TestGpuIndexIVFScalarQuantizer, CopyTo_6bit) {
runCopyToTest(faiss::ScalarQuantizer::QuantizerType::QT_6bit);
}
TEST(TestGpuIndexIVFScalarQuantizer, CopyTo_4bit) {
runCopyToTest(faiss::ScalarQuantizer::QuantizerType::QT_4bit);
}
TEST(TestGpuIndexIVFScalarQuantizer, CopyTo_4bit_uniform) {
runCopyToTest(faiss::ScalarQuantizer::QuantizerType::QT_4bit_uniform);
}
void runCopyFromTest(faiss::ScalarQuantizer::QuantizerType qtype) {
using namespace faiss;
using namespace faiss::gpu;
for (auto qtype : {ScalarQuantizer::QuantizerType::QT_8bit,
ScalarQuantizer::QuantizerType::QT_4bit}) {
Options opt;
std::vector<float> trainVecs = randVecs(opt.numTrain, opt.dim);
std::vector<float> addVecs = randVecs(opt.numAdd, opt.dim);
IndexFlatL2 cpuQuantizer(opt.dim);
IndexIVFScalarQuantizer cpuIndex(&cpuQuantizer, opt.dim, opt.numCentroids,
qtype,
METRIC_L2);
cpuIndex.nprobe = opt.nprobe;
cpuIndex.train(opt.numTrain, trainVecs.data());
cpuIndex.add(opt.numAdd, addVecs.data());
// use garbage values to see if we overwrite them
StandardGpuResources res;
res.noTempMemory();
auto config = GpuIndexIVFScalarQuantizerConfig();
config.device = opt.device;
GpuIndexIVFScalarQuantizer gpuIndex(
&res,
1,
1,
ScalarQuantizer::QuantizerType::QT_4bit,
METRIC_L2,
false,
config);
gpuIndex.setNumProbes(1);
gpuIndex.copyFrom(&cpuIndex);
EXPECT_EQ(cpuIndex.ntotal, gpuIndex.ntotal);
EXPECT_EQ(gpuIndex.ntotal, opt.numAdd);
EXPECT_EQ(cpuIndex.d, gpuIndex.d);
EXPECT_EQ(cpuIndex.d, opt.dim);
EXPECT_EQ(cpuIndex.nlist, gpuIndex.getNumLists());
EXPECT_EQ(cpuIndex.nprobe, gpuIndex.getNumProbes());
testIVFEquality(cpuIndex, gpuIndex);
// Query both objects; results should be equivalent
compareIndices(cpuIndex, gpuIndex,
opt.numQuery, opt.dim, opt.k, opt.toString(),
kF32MaxRelErr,
0.1f,
0.015f);
}
}
TEST(TestGpuIndexIVFScalarQuantizer, CopyFrom_fp16) {
runCopyFromTest(faiss::ScalarQuantizer::QuantizerType::QT_fp16);
}
TEST(TestGpuIndexIVFScalarQuantizer, CopyFrom_8bit) {
runCopyFromTest(faiss::ScalarQuantizer::QuantizerType::QT_8bit);
}
TEST(TestGpuIndexIVFScalarQuantizer, CopyFrom_8bit_uniform) {
runCopyFromTest(faiss::ScalarQuantizer::QuantizerType::QT_8bit_uniform);
}
TEST(TestGpuIndexIVFScalarQuantizer, CopyFrom_6bit) {
runCopyFromTest(faiss::ScalarQuantizer::QuantizerType::QT_6bit);
}
TEST(TestGpuIndexIVFScalarQuantizer, CopyFrom_4bit) {
runCopyFromTest(faiss::ScalarQuantizer::QuantizerType::QT_4bit);
}
TEST(TestGpuIndexIVFScalarQuantizer, CopyFrom_4bit_uniform) {
runCopyFromTest(faiss::ScalarQuantizer::QuantizerType::QT_4bit_uniform);
}
int main(int argc, char** argv) {

View File

@ -109,7 +109,9 @@ void testIVFEquality(A& cpuIndex, B& gpuIndex) {
auto sc = faiss::InvertedLists::ScopedCodes(cpuLists, i);
std::memcpy(cpuCodes.data(), sc.get(),
cpuLists->list_size(i) * cpuLists->code_size);
EXPECT_EQ(cpuCodes, gpuIndex.getListVectorData(i));
auto gpuCodes = gpuIndex.getListVectorData(i, false);
EXPECT_EQ(cpuCodes, gpuCodes);
// Index equality
std::vector<Index::idx_t> cpuIndices(cpuLists->list_size(i));

View File

@ -364,7 +364,7 @@ class TestInterleavedIVFPQLayout(unittest.TestCase):
nprobe = 4
config = faiss.GpuIndexIVFPQConfig()
config.alternativeLayout = True
config.interleavedLayout = True
idx_gpu = faiss.GpuIndexIVFPQ(res, d, nlist, sub_q, bits_per_code, faiss.METRIC_L2, config)
q = faiss.IndexFlatL2(d)
idx_cpu = faiss.IndexIVFPQ(q, d, nlist, sub_q, bits_per_code, faiss.METRIC_L2)
@ -406,7 +406,7 @@ class TestInterleavedIVFPQLayout(unittest.TestCase):
nprobe = 4
config = faiss.GpuIndexIVFPQConfig()
config.alternativeLayout = True
config.interleavedLayout = True
idx_gpu = faiss.GpuIndexIVFPQ(res, d, nlist, sub_q, bits_per_code, faiss.METRIC_L2, config)
q = faiss.IndexFlatL2(d)
idx_cpu = faiss.IndexIVFPQ(q, d, nlist, sub_q, bits_per_code, faiss.METRIC_L2)
@ -448,7 +448,7 @@ class TestInterleavedIVFPQLayout(unittest.TestCase):
nprobe = 4
config = faiss.GpuIndexIVFPQConfig()
config.alternativeLayout = True
config.interleavedLayout = True
idx_gpu = faiss.GpuIndexIVFPQ(res, d, nlist, sub_q, bits_per_code, faiss.METRIC_L2, config)
q = faiss.IndexFlatL2(d)
idx_cpu = faiss.IndexIVFPQ(q, d, nlist, sub_q, bits_per_code, faiss.METRIC_L2)

View File

@ -180,7 +180,7 @@ def do_multi_test(qtype):
nprobe = 10
k = 50
for d in [11, 64]:
for d in [11, 64, 77]:
if (qtype != faiss.ScalarQuantizer.QT_8bit_direct):
# residual doesn't make sense here
do_test(nlist, d, qtype, True,
@ -205,13 +205,7 @@ class TestSQ(unittest.TestCase):
do_multi_test(faiss.ScalarQuantizer.QT_8bit_uniform)
def test_6bit(self):
try:
do_multi_test(faiss.ScalarQuantizer.QT_6bit)
# should not reach here; QT_6bit is unimplemented
except:
print('QT_6bit exception thrown (is expected)')
else:
assert(False)
do_multi_test(faiss.ScalarQuantizer.QT_6bit)
def test_4bit(self):
do_multi_test(faiss.ScalarQuantizer.QT_4bit)

View File

@ -65,6 +65,26 @@ DeviceTensor<T, Dim, true> toDeviceNonTemporary(
}
}
template <typename T>
DeviceTensor<T, 1, true> toDeviceTemporary(
GpuResources* resources,
const std::vector<T>& src,
cudaStream_t stream,
int device = -1) {
// Uses the current device if device == -1
DeviceScope scope(device);
FAISS_ASSERT(src.size() <
(size_t) std::numeric_limits<int>::max());
DeviceTensor<T, 1, true> out(
resources, makeTempAlloc(AllocType::Other, stream),
{(int) src.size()});
out.copyFrom(src, stream);
return out;
}
/// Copies data to the CPU, if it is not already on the CPU
template <typename T, int Dim>
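A hedged usage sketch of the new std::vector overload of toDeviceTemporary added above; the wrapper function and the list-id payload are only for illustration and are not part of the diff:

```cpp
#include <faiss/gpu/GpuResources.h>
#include <faiss/gpu/utils/CopyUtils.cuh>

#include <vector>

// Stage a small host vector in temporary device memory for the lifetime of a
// kernel launch on the resources' default stream.
void stageListIds(faiss::gpu::GpuResources* resources) {
  auto stream = resources->getDefaultStreamCurrentDevice();
  std::vector<int> listIds = {0, 5, 7};

  // 1-D DeviceTensor<int> backed by temporary memory, valid for use on `stream`
  auto deviceListIds =
      faiss::gpu::toDeviceTemporary(resources, listIds, stream);

  // ... pass deviceListIds to a kernel here ...
}
```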

View File

@ -155,13 +155,18 @@ int getMaxKSelection() {
}
DeviceScope::DeviceScope(int device) {
prevDevice_ = getCurrentDevice();
if (device >= 0) {
int curDevice = getCurrentDevice();
if (prevDevice_ != device) {
setCurrentDevice(device);
} else {
prevDevice_ = -1;
if (curDevice != device) {
prevDevice_ = curDevice;
setCurrentDevice(device);
return;
}
}
// Otherwise, we keep the current device
prevDevice_ = -1;
}
DeviceScope::~DeviceScope() {
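For reference, a minimal sketch of the RAII behavior this change produces (ordinary DeviceScope usage, nothing beyond what the diff shows): the constructor records and switches the device only when the requested device differs from the current one, and a negative device now simply means "keep the current device".

```cpp
{
  faiss::gpu::DeviceScope scope(1); // switches to device 1 only if we are not already on it
  // ... allocate and launch work on device 1 ...
} // the destructor restores the previous device only if the constructor switched
```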

View File

@ -22,6 +22,9 @@ namespace faiss { namespace gpu {
/// over whether resize() initializes new space with T() (which we
/// don't want), and control on how much the reserved space grows by
/// upon resize/reserve. It is also meant for POD types only.
///
/// Any new memory allocated is automatically zeroed before being presented to
/// the user.
template <typename T>
class DeviceVector {
public:
@ -157,6 +160,7 @@ class DeviceVector {
FAISS_ASSERT(num_ <= newCapacity);
size_t newSizeInBytes = newCapacity * sizeof(T);
size_t oldSizeInBytes = num_ * sizeof(T);
// The new allocation will occur on this stream
allocInfo_.stream = stream;
@ -164,8 +168,18 @@ class DeviceVector {
auto newAlloc =
res_->allocMemoryHandle(AllocRequest(allocInfo_, newSizeInBytes));
CUDA_VERIFY(cudaMemcpyAsync(newAlloc.data, data(), num_ * sizeof(T),
// Copy over any old data
CUDA_VERIFY(cudaMemcpyAsync(newAlloc.data,
data(),
oldSizeInBytes,
cudaMemcpyDeviceToDevice, stream));
// Zero out the new space past the data we just copied
CUDA_VERIFY(cudaMemsetAsync((uint8_t*) newAlloc.data + oldSizeInBytes,
0,
newSizeInBytes - oldSizeInBytes,
stream));
alloc_ = std::move(newAlloc);
capacity_ = newCapacity;
}

View File

@ -12,6 +12,13 @@
namespace faiss { namespace gpu {
// defines to simplify the SASS assembly structure file/line in the profiler
#define GET_BITFIELD_U32(OUT, VAL, POS, LEN) \
asm("bfe.u32 %0, %1, %2, %3;" : "=r"(OUT) : "r"(VAL), "r"(POS), "r"(LEN));
#define GET_BITFIELD_U64(OUT, VAL, POS, LEN) \
asm("bfe.u64 %0, %1, %2, %3;" : "=l"(OUT) : "l"(VAL), "r"(POS), "r"(LEN));
__device__ __forceinline__
unsigned int getBitfield(unsigned int val, int pos, int len) {
unsigned int ret;

View File

@ -10,6 +10,12 @@
#include <cuda.h>
// allow usage for non-CUDA files
#ifndef __host__
#define __host__
#define __device__
#endif
namespace faiss { namespace gpu { namespace utils {
template <typename U, typename V>

View File

@ -8,6 +8,7 @@
#include <faiss/gpu/GpuFaissAssert.h>
#include <faiss/gpu/utils/DeviceUtils.h>
#include <cstring>
#include <limits>
namespace faiss { namespace gpu {
@ -202,6 +203,58 @@ Tensor<T, Dim, InnerContig, IndexT, PtrTraits>::copyTo(
}
}
template <typename T, int Dim, bool InnerContig,
typename IndexT, template <typename U> class PtrTraits>
__host__ void
Tensor<T, Dim, InnerContig, IndexT, PtrTraits>::copyFrom(
const std::vector<T>& v,
cudaStream_t stream) {
// The tensor must be fully contiguous
GPU_FAISS_ASSERT(this->isContiguous());
// Size must be the same
GPU_FAISS_ASSERT(this->numElements() == v.size());
if (v.size() > 0) {
GPU_FAISS_ASSERT(this->data_);
int ourDev = getDeviceForAddress(this->data_);
CUDA_VERIFY(cudaMemcpyAsync(this->data_,
v.data(),
this->getSizeInBytes(),
ourDev == -1 ? cudaMemcpyHostToHost :
cudaMemcpyHostToDevice,
stream));
}
}
template <typename T, int Dim, bool InnerContig,
typename IndexT, template <typename U> class PtrTraits>
__host__ std::vector<T>
Tensor<T, Dim, InnerContig, IndexT, PtrTraits>::copyToVector(
cudaStream_t stream) {
// The tensor must be fully contiguous
GPU_FAISS_ASSERT(this->isContiguous());
std::vector<T> out(this->numElements());
if (!out.empty()) {
int ourDev = getDeviceForAddress(this->data_);
if (ourDev == -1) {
std::memcpy(out.data(), this->data_, this->numElements() * sizeof(T));
} else {
CUDA_VERIFY(cudaMemcpyAsync(out.data(),
this->data_,
this->numElements() * sizeof(T),
cudaMemcpyDeviceToHost,
stream));
}
}
return out;
}
template <typename T, int Dim, bool InnerContig,
typename IndexT, template <typename U> class PtrTraits>
template <typename OtherT, int OtherDim>

View File

@ -12,6 +12,7 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include <initializer_list>
#include <vector>
/// Multi-dimensional array class for CUDA device and host usage.
/// Originally from Facebook's fbcunn, since added to the Torch GPU
@ -119,6 +120,16 @@ class Tensor {
__host__ void copyTo(Tensor<T, Dim, InnerContig, IndexT, PtrTraits>& t,
cudaStream_t stream);
/// Copies a CPU std::vector<T> into ourselves, allocating memory for it.
/// The total size of our Tensor must match vector<T>::size(), though
/// we are not restricted to 1D Tensors to match the 1D vector<T>.
/// `stream` specifies the stream of the copy and thus the stream on which the
/// memory will initially be used.
__host__ void copyFrom(const std::vector<T>& v, cudaStream_t stream);
/// Copies ourselves into a flattened (1D) std::vector, using the given stream
__host__ std::vector<T> copyToVector(cudaStream_t stream);
/// Returns true if the two tensors are of the same dimensionality,
/// size and stride.
template <typename OtherT, int OtherDim>
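A hedged usage sketch of the new copyFrom / copyToVector methods declared above; the temporary-allocation setup mirrors what the diff itself uses in CopyUtils.cuh, and the wrapper function is only for illustration:

```cpp
#include <faiss/gpu/GpuResources.h>
#include <faiss/gpu/utils/DeviceTensor.cuh>

#include <vector>

std::vector<float> roundTripThroughDevice(faiss::gpu::GpuResources* res) {
  auto stream = res->getDefaultStreamCurrentDevice();
  std::vector<float> host(16 * 32, 1.0f);

  // Any shape is allowed as long as numElements() matches host.size()
  faiss::gpu::DeviceTensor<float, 2, true> t(
      res,
      faiss::gpu::makeTempAlloc(faiss::gpu::AllocType::Other, stream),
      {16, 32});

  t.copyFrom(host, stream);      // host -> device copy on `stream`
  return t.copyToVector(stream); // device -> flattened host std::vector
}
```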

View File

@ -0,0 +1,163 @@
/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <faiss/gpu/utils/PtxUtils.cuh>
#include <faiss/gpu/utils/WarpShuffles.cuh>
namespace faiss { namespace gpu {
//
// Warp-coalesced parallel reading and writing of packed bits
//
// Read/write native word sizes
template <typename WordT, int Bits>
struct WarpPackedBits {
static __device__ void write(int laneId, WordT v, bool valid, WordT* out) {
static_assert(sizeof(WordT) == Bits / 8 &&
(Bits % 8) == 0, "");
// We can just write directly
if (valid) {
out[laneId] = v;
}
}
static inline __device__ WordT read(int laneId, WordT* in) {
return in[laneId];
}
static inline __device__ WordT postRead(int laneId, WordT v) {
return v;
}
};
// Read/write 6 bit fields, packed across the warp into 24 bytes
template <>
struct WarpPackedBits<uint8_t, 6> {
static __device__ void write(int laneId, uint8_t v, bool valid, uint8_t* out) {
// Lower 24 lanes write out packed data
int laneFrom = (laneId * 8) / 6;
v = valid ? v : 0;
v &= 0x3f; // ensure we have only 6 bits
uint8_t vLower = (uint8_t) shfl((unsigned int) v, laneFrom);
uint8_t vUpper = (uint8_t) shfl((unsigned int) v, laneFrom + 1);
// lsb ... msb
// 0: 0 0 0 0 0 0 1 1
// 1: 1 1 1 1 2 2 2 2
// 2: 2 2 3 3 3 3 3 3
int typeLane = laneId % 3;
uint8_t vOut = 0;
switch (typeLane) {
case 0:
// 6 msbs of lower as vOut lsbs
// 2 lsbs of upper as vOut msbs
vOut = vLower | (vUpper << 6);
break;
case 1:
// 4 msbs of lower as vOut lsbs
// 4 lsbs of upper as vOut msbs
vOut = (vLower >> 2) | (vUpper << 4);
break;
case 2:
// 2 msbs of lower as vOut lsbs
// 6 lsbs of upper as vOut msbs
vOut = (vLower >> 4) | (vUpper << 2);
break;
}
if (laneId < 24) {
// There could be prior data
out[laneId] |= vOut;
}
}
static inline __device__ uint8_t read(int laneId, uint8_t* in) {
uint8_t v = 0;
if (laneId < 24) {
v = in[laneId];
}
return v;
}
static inline __device__ uint8_t postRead(int laneId, uint8_t v) {
int laneFrom = (laneId * 6) / 8;
// auto vLower = shfl((unsigned int) v, laneFrom);
// auto vUpper = shfl((unsigned int) v, laneFrom + 1);
auto vLower = SHFL_SYNC((unsigned int) v, laneFrom, kWarpSize);
auto vUpper = SHFL_SYNC((unsigned int) v, laneFrom + 1, kWarpSize);
auto vConcat = (vUpper << 8) | vLower;
// Now, this is weird. Each lane reads two uint8, but we wish to use the
// bfe.u32 instruction to read a 6 bit value from the concatenated uint32.
// The offset in which we wish to read in the concatenated word is the
// following:
//
// 0: 0, 1: offset 0 size 6
// 1: 0, 1: offset 6 size 6
// 2: 1, 2: offset 4 size 6
// 3: 2, 3: offset 2 size 6
//
// In binary, each of the offsets are the following (concatenated together):
// 0b0010'0100'0110'0000 or 0x2460
// We can thus use bfe.u32 as a lookup table for the above sequence.
unsigned int pos;
GET_BITFIELD_U32(pos, 0x2460, (laneId & 0x3) * 4, 4);
unsigned int out;
GET_BITFIELD_U32(out, vConcat, pos, 6);
return out;
}
};
// Read/write 4 bit fields, packed across the warp into 16 bytes
template <>
struct WarpPackedBits<uint8_t, 4> {
static __device__ void write(int laneId, uint8_t v, bool valid, uint8_t* out) {
// Lower 16 lanes write out packed data
int laneFrom = laneId * 2;
v = valid ? v : 0;
uint8_t vLower = (uint8_t) shfl((unsigned int) v, laneFrom);
uint8_t vUpper = (uint8_t) shfl((unsigned int) v, laneFrom + 1);
uint8_t vOut = (vLower & 0xf) | (vUpper << 4);
if (laneId < 16) {
// There could be prior data
out[laneId] |= vOut;
}
}
static inline __device__ uint8_t read(int laneId, uint8_t* in) {
uint8_t v = 0;
if (laneId < 16) {
v = in[laneId];
}
return v;
}
static inline __device__ uint8_t postRead(int laneId, uint8_t v) {
int laneFrom = laneId / 2;
auto v2 = shfl((unsigned int) v, laneFrom);
return getBitfield(v2, (laneId & 0x1) * 4, 4);
}
};
} } // namespace
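To make the 6-bit layout above easier to verify, here is a host-side reference sketch (hypothetical helper names, not part of the diff) that produces the same packing the warp-level write() assembles and the bfe-based postRead() extracts: 32 six-bit codes laid out contiguously, LSB first, in 32 * 6 / 8 = 24 bytes.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Pack 32 six-bit values contiguously (LSB first) into 24 bytes.
std::vector<uint8_t> pack32x6(const std::vector<uint8_t>& vals) {
  assert(vals.size() == 32);
  std::vector<uint8_t> out(24, 0);
  for (int i = 0; i < 32; ++i) {
    uint32_t v = vals[i] & 0x3f; // keep only 6 bits
    int bit = i * 6;             // bit offset within the 192-bit group
    out[bit / 8] |= uint8_t(v << (bit % 8));
    if (bit % 8 > 2) {           // the value straddles a byte boundary
      out[bit / 8 + 1] |= uint8_t(v >> (8 - (bit % 8)));
    }
  }
  return out;
}

// Extract value i; mirrors the concatenate-two-bytes-then-extract read above.
uint8_t unpack6(const std::vector<uint8_t>& packed, int i) {
  int bit = i * 6;
  uint32_t concat = packed[bit / 8];
  if (bit / 8 + 1 < (int)packed.size()) {
    concat |= uint32_t(packed[bit / 8 + 1]) << 8;
  }
  return (concat >> (bit % 8)) & 0x3f;
}
```

The same arithmetic with a field width of 4 instead of 6 gives the 16-byte layout handled by the WarpPackedBits<uint8_t, 4> specialization.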

View File

@ -13,6 +13,15 @@
namespace faiss { namespace gpu {
// defines to simplify the SASS assembly structure file/line in the profiler
#if CUDA_VERSION >= 9000
#define SHFL_SYNC(VAL, SRC_LANE, WIDTH) \
__shfl_sync(0xffffffff, VAL, SRC_LANE, WIDTH)
#else
#define SHFL_SYNC(VAL, SRC_LANE, WIDTH) \
__shfl(VAL, SRC_LANE, WIDTH)
#endif
template <typename T>
inline __device__ T shfl(const T val,
int srcLane, int width = kWarpSize) {

View File

@ -1220,33 +1220,41 @@ SQDistanceComputer *select_distance_computer (
ScalarQuantizer::ScalarQuantizer
(size_t d, QuantizerType qtype):
qtype (qtype), rangestat(RS_minmax), rangestat_arg(0), d (d)
qtype (qtype), rangestat(RS_minmax), rangestat_arg(0), d(d)
{
switch (qtype) {
case QT_8bit:
case QT_8bit_uniform:
case QT_8bit_direct:
code_size = d;
break;
case QT_4bit:
case QT_4bit_uniform:
code_size = (d + 1) / 2;
break;
case QT_6bit:
code_size = (d * 6 + 7) / 8;
break;
case QT_fp16:
code_size = d * 2;
break;
}
set_derived_sizes();
}
ScalarQuantizer::ScalarQuantizer ():
qtype(QT_8bit),
rangestat(RS_minmax), rangestat_arg(0), d (0), code_size(0)
rangestat(RS_minmax), rangestat_arg(0), d(0), bits(0), code_size(0)
{}
void ScalarQuantizer::set_derived_sizes ()
{
switch (qtype) {
case QT_8bit:
case QT_8bit_uniform:
case QT_8bit_direct:
code_size = d;
bits = 8;
break;
case QT_4bit:
case QT_4bit_uniform:
code_size = (d + 1) / 2;
bits = 4;
break;
case QT_6bit:
code_size = (d * 6 + 7) / 8;
bits = 6;
break;
case QT_fp16:
code_size = d * 2;
bits = 16;
break;
}
}
void ScalarQuantizer::train (size_t n, const float *x)
{
int bit_per_dim =

View File

@ -53,6 +53,9 @@ struct ScalarQuantizer {
/// dimension of input vectors
size_t d;
/// bits per scalar code
size_t bits;
/// bytes per vector
size_t code_size;
@ -62,6 +65,9 @@ struct ScalarQuantizer {
ScalarQuantizer (size_t d, QuantizerType qtype);
ScalarQuantizer ();
/// updates internal values based on qtype and d
void set_derived_sizes ();
void train (size_t n, const float *x);
/// Used by an IVF index to train based on the residuals
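As a concrete check of the values set_derived_sizes() now produces: for d = 128, QT_8bit gives bits = 8 and code_size = 128 bytes, QT_6bit gives bits = 6 and code_size = (128 * 6 + 7) / 8 = 96 bytes, QT_4bit gives bits = 4 and code_size = (128 + 1) / 2 = 64 bytes, and QT_fp16 gives bits = 16 and code_size = 256 bytes.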

View File

@ -246,6 +246,7 @@ static void read_ScalarQuantizer (ScalarQuantizer *ivsc, IOReader *f) {
READ1 (ivsc->d);
READ1 (ivsc->code_size);
READVECTOR (ivsc->trained);
ivsc->set_derived_sizes ();
}