faiss/faiss/gpu/impl/InterleavedCodes.cpp

411 lines
12 KiB
C++

/**
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <faiss/gpu/impl/InterleavedCodes.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/gpu/utils/StaticUtils.h>
namespace faiss { namespace gpu {
/**
 * Expands vector-major ("non-interleaved") bit-packed codes so that every
 * code occupies a whole number of bytes.
 *
 * Each input vector occupies divUp(dims * bitsPerCode, 8) bytes; inside a
 * vector the codes are packed back to back starting at the least significant
 * bit of the first byte.
 *
 * @param data        packed codes; must hold exactly
 *                    numVecs * divUp(dims * bitsPerCode, 8) bytes
 * @param numVecs     number of encoded vectors
 * @param dims        number of codes per vector
 * @param bitsPerCode width of a single code; 4, 6, 8, 16 and 32 are handled
 * @return for 8/16/32-bit codes the input is returned unchanged; for 4- and
 *         6-bit codes one byte per code, stored at out[vec * dims + dim].
 *         Any other width trips FAISS_ASSERT.
 */
std::vector<uint8_t>
unpackNonInterleaved(std::vector<uint8_t> data,
                     int numVecs,
                     int dims,
                     int bitsPerCode) {
    // All byte counts and offsets are carried in size_t: the products below
    // exceed the range of a 32-bit int once the code buffer passes 2 GiB.
    const size_t vecCount = static_cast<size_t>(numVecs);
    const size_t srcVecSize = utils::divUp(dims * bitsPerCode, 8);
    FAISS_ASSERT(data.size() == vecCount * srcVecSize);

    if (bitsPerCode == 8 || bitsPerCode == 16 || bitsPerCode == 32) {
        // whole-byte codes are already unpacked
        return data;
    }

    // bit codes padded to whole bytes
    std::vector<uint8_t> out(vecCount * dims * utils::divUp(bitsPerCode, 8));

    if (bitsPerCode == 6) {
#pragma omp parallel for
        for (int i = 0; i < numVecs; ++i) {
            const size_t srcBase = static_cast<size_t>(i) * srcVecSize;
            const size_t dstBase = static_cast<size_t>(i) * dims;

            for (int j = 0; j < dims; ++j) {
                // Byte holding the first bit of code j, and the byte after it
                const size_t lo = srcBase + (static_cast<size_t>(j) * 6) / 8;
                const size_t hi = lo + 1;

                FAISS_ASSERT(lo < data.size());
                FAISS_ASSERT(hi <= data.size());

                const uint8_t vLower = data[lo];
                // The last code of the buffer may have no following byte
                const uint8_t vUpper =
                        hi < data.size() ? data[hi] : uint8_t{0};

                uint8_t v = 0;

                // Four 6-bit codes span three bytes; j % 4 selects where the
                // code starts inside its first byte.
                switch (j % 4) {
                    case 0:
                        // 6 lsbs of lower
                        v = vLower & 0x3f;
                        break;
                    case 1:
                        // 2 msbs of lower as v lsbs
                        // 4 lsbs of upper as v msbs
                        v = (vLower >> 6) | ((vUpper & 0xf) << 2);
                        break;
                    case 2:
                        // 4 msbs of lower as v lsbs
                        // 2 lsbs of upper as v msbs
                        v = (vLower >> 4) | ((vUpper & 0x3) << 4);
                        break;
                    case 3:
                        // 6 msbs of lower
                        v = (vLower >> 2);
                        break;
                }

                out[dstBase + j] = v;
            }
        }
    } else if (bitsPerCode == 4) {
#pragma omp parallel for
        for (int i = 0; i < numVecs; ++i) {
            const size_t srcBase = static_cast<size_t>(i) * srcVecSize;
            const size_t dstBase = static_cast<size_t>(i) * dims;

            for (int j = 0; j < dims; ++j) {
                const size_t srcIdx = srcBase + (j / 2);
                FAISS_ASSERT(srcIdx < data.size());

                // Even dimensions live in the low nibble, odd ones in the
                // high nibble
                const uint8_t packed = data[srcIdx];
                out[dstBase + j] = (j % 2 == 0) ? (packed & 0xf) : (packed >> 4);
            }
        }
    } else {
        // unhandled
        FAISS_ASSERT(false);
    }

    return out;
}
/**
 * Gathers whole-word codes from the 32-vector interleaved layout into
 * vector-major order.
 *
 * Vectors are grouped into blocks of 32. Inside a block every dimension owns
 * wordsPerDimBlock consecutive words of type T, and a vector's word for that
 * dimension sits at its lane (index within the block).
 *
 * @param in          interleaved source words
 * @param out         destination; receives out[vec * dims + dim]
 * @param numVecs     number of vectors to extract
 * @param dims        number of codes per vector
 * @param bitsPerCode width of one code in bits (the callers in this file pass
 *                    8 * sizeof(T))
 */
template <typename T>
void
unpackInterleavedWord(const T* in,
                      T* out,
                      int numVecs,
                      int dims,
                      int bitsPerCode) {
    // Offsets are kept in size_t so that large inputs cannot overflow the
    // 32-bit int arithmetic that a plain `block * wordsPerBlock` would use.
    const size_t wordsPerDimBlock = 32 * bitsPerCode / (8 * sizeof(T));
    const size_t wordsPerBlock = wordsPerDimBlock * dims;
    const int numBlocks = utils::divUp(numVecs, 32);

#pragma omp parallel for
    for (int i = 0; i < numVecs; ++i) {
        const int block = i / 32;
        FAISS_ASSERT(block < numBlocks);

        const int lane = i % 32;
        const size_t srcBase = static_cast<size_t>(block) * wordsPerBlock + lane;
        const size_t dstBase = static_cast<size_t>(i) * dims;

        for (int j = 0; j < dims; ++j) {
            out[dstBase + j] = in[srcBase + j * wordsPerDimBlock];
        }
    }
}
/**
 * Converts codes from the 32-vector interleaved layout into vector-major
 * order with every code padded to a whole number of bytes.
 *
 * The input is a sequence of blocks of 32 vectors. Inside a block each
 * dimension owns 32 * bitsPerCode / 8 bytes holding that dimension's code for
 * all 32 vectors of the block, bit-packed in lane order.
 *
 * @param data        interleaved codes; must hold exactly
 *                    divUp(numVecs, 32) * dims * (32 * bitsPerCode / 8) bytes
 * @param numVecs     number of vectors actually stored
 * @param dims        number of codes per vector
 * @param bitsPerCode width of one code; 4, 6, 8, 16 and 32 are handled
 * @return numVecs * dims * divUp(bitsPerCode, 8) bytes laid out as
 *         [vec][dim]. Any other width trips FAISS_ASSERT.
 */
std::vector<uint8_t>
unpackInterleaved(std::vector<uint8_t> data,
                  int numVecs,
                  int dims,
                  int bitsPerCode) {
    // Sizes and offsets are computed in size_t; the int products used before
    // overflow once the buffers pass 2 GiB.
    const int bytesPerDimBlock = 32 * bitsPerCode / 8;
    const size_t bytesPerBlock = static_cast<size_t>(bytesPerDimBlock) * dims;
    const int numBlocks = utils::divUp(numVecs, 32);
    const size_t totalSize = bytesPerBlock * numBlocks;
    FAISS_ASSERT(data.size() == totalSize);

    // bit codes padded to whole bytes
    std::vector<uint8_t> out(
            static_cast<size_t>(numVecs) * dims *
            utils::divUp(bitsPerCode, 8));

    if (bitsPerCode == 8) {
        unpackInterleavedWord<uint8_t>(
                data.data(), out.data(), numVecs, dims, bitsPerCode);
    } else if (bitsPerCode == 16) {
        unpackInterleavedWord<uint16_t>(
                reinterpret_cast<uint16_t*>(data.data()),
                reinterpret_cast<uint16_t*>(out.data()),
                numVecs, dims, bitsPerCode);
    } else if (bitsPerCode == 32) {
        unpackInterleavedWord<uint32_t>(
                reinterpret_cast<uint32_t*>(data.data()),
                reinterpret_cast<uint32_t*>(out.data()),
                numVecs, dims, bitsPerCode);
    } else if (bitsPerCode == 4) {
#pragma omp parallel for
        for (int i = 0; i < numVecs; ++i) {
            const int lane = i % 32;
            // Two lanes share a byte: even lane in the low nibble, odd lane
            // in the high nibble
            const int word = lane / 2;
            const bool highNibble = (lane % 2) != 0;

            const size_t blockBase = static_cast<size_t>(i / 32) * bytesPerBlock;
            const size_t dstBase = static_cast<size_t>(i) * dims;

            for (int j = 0; j < dims; ++j) {
                const uint8_t packed =
                        data[blockBase +
                             static_cast<size_t>(j) * bytesPerDimBlock + word];
                out[dstBase + j] = highNibble ? (packed >> 4) : (packed & 0xf);
            }
        }
    } else if (bitsPerCode == 6) {
#pragma omp parallel for
        for (int i = 0; i < numVecs; ++i) {
            const int blockVector = i % 32;
            const size_t blockBase = static_cast<size_t>(i / 32) * bytesPerBlock;
            const size_t dstBase = static_cast<size_t>(i) * dims;

            // Byte (relative to the dimension's slice) where this lane's code
            // starts, and the byte after it
            const int lo = (blockVector * 6) / 8;
            const int hi = lo + 1;

            FAISS_ASSERT(lo < bytesPerDimBlock);
            FAISS_ASSERT(hi <= bytesPerDimBlock);

            for (int j = 0; j < dims; ++j) {
                const uint8_t* dimBlock =
                        &data[blockBase +
                              static_cast<size_t>(j) * bytesPerDimBlock];

                const uint8_t vLower = dimBlock[lo];
                // The last lane of the slice has no following byte
                const uint8_t vUpper =
                        hi < bytesPerDimBlock ? dimBlock[hi] : uint8_t{0};

                uint8_t v = 0;

                // Four 6-bit codes span three bytes; the lane modulo 4
                // selects where the code starts inside its first byte.
                switch (blockVector % 4) {
                    case 0:
                        // 6 lsbs of lower
                        v = vLower & 0x3f;
                        break;
                    case 1:
                        // 2 msbs of lower as v lsbs
                        // 4 lsbs of upper as v msbs
                        v = (vLower >> 6) | ((vUpper & 0xf) << 2);
                        break;
                    case 2:
                        // 4 msbs of lower as v lsbs
                        // 2 lsbs of upper as v msbs
                        v = (vLower >> 4) | ((vUpper & 0x3) << 4);
                        break;
                    case 3:
                        // 6 msbs of lower
                        v = (vLower >> 2);
                        break;
                }

                out[dstBase + j] = v;
            }
        }
    } else {
        // unimplemented
        FAISS_ASSERT(false);
    }

    return out;
}
/**
 * Bit-packs byte-padded codes into the vector-major ("non-interleaved")
 * layout.
 *
 * Every output vector occupies divUp(dims * bitsPerCode, 8) bytes; the codes
 * of a vector are packed back to back starting at the least significant bit
 * of the first byte. Unused trailing bits of a vector are zero.
 *
 * @param data        one code per divUp(bitsPerCode, 8) bytes, in [vec][dim]
 *                    order; must hold exactly
 *                    numVecs * dims * divUp(bitsPerCode, 8) bytes
 * @param numVecs     number of vectors
 * @param dims        number of codes per vector
 * @param bitsPerCode width of one code; 4, 6, 8, 16 and 32 are handled
 * @return for 8/16/32-bit codes the input is returned unchanged; otherwise
 *         the packed bytes. Any other width trips FAISS_ASSERT.
 */
std::vector<uint8_t>
packNonInterleaved(std::vector<uint8_t> data,
                   int numVecs,
                   int dims,
                   int bitsPerCode) {
    // Sizes and offsets are computed in size_t; the int products used before
    // overflow once the buffers pass 2 GiB.
    const size_t vecCount = static_cast<size_t>(numVecs);

    // bit codes padded to whole bytes
    FAISS_ASSERT(data.size() == vecCount * dims * utils::divUp(bitsPerCode, 8));

    if (bitsPerCode == 8 || bitsPerCode == 16 || bitsPerCode == 32) {
        // nothing to do, whole words are already where they need to be
        return data;
    }

    // bits packed into a whole number of bytes
    const int bytesPerVec = utils::divUp(dims * bitsPerCode, 8);
    std::vector<uint8_t> out(vecCount * bytesPerVec);

    if (bitsPerCode == 4) {
#pragma omp parallel for
        for (int i = 0; i < numVecs; ++i) {
            const size_t srcBase = static_cast<size_t>(i) * dims;
            const size_t dstBase = static_cast<size_t>(i) * bytesPerVec;

            for (int j = 0; j < bytesPerVec; ++j) {
                const int dimLo = j * 2;
                const int dimHi = dimLo + 1;

                FAISS_ASSERT(dimLo < dims);
                FAISS_ASSERT(dimHi <= dims);

                const uint8_t lo = data[srcBase + dimLo];
                // An odd dims leaves the last high nibble empty
                const uint8_t hi = dimHi < dims ? data[srcBase + dimHi] : 0;

                out[dstBase + j] = (hi << 4) | (lo & 0xf);
            }
        }
    } else if (bitsPerCode == 6) {
#pragma omp parallel for
        for (int i = 0; i < numVecs; ++i) {
            const size_t srcBase = static_cast<size_t>(i) * dims;
            const size_t dstBase = static_cast<size_t>(i) * bytesPerVec;

            for (int j = 0; j < bytesPerVec; ++j) {
                // Code whose bits open output byte j, and the code after it
                const int dimLo = (j * 8) / 6;
                const int dimHi = dimLo + 1;

                FAISS_ASSERT(dimLo < dims);
                FAISS_ASSERT(dimHi <= dims);

                // Keep only the 6 payload bits of the lower code so that
                // stray high bits cannot spill into the neighbouring code
                // once it is shifted down. The upper code needs no mask: the
                // left shift plus the uint8_t store discards its excess bits.
                const uint8_t lo = data[srcBase + dimLo] & 0x3f;
                const uint8_t hi = dimHi < dims ? data[srcBase + dimHi] : 0;

                uint8_t v = 0;

                // lsb     ...    msb
                // 0: 0 0 0 0 0 0 1 1
                // 1: 1 1 1 1 2 2 2 2
                // 2: 2 2 3 3 3 3 3 3
                switch (j % 3) {
                    case 0:
                        // all 6 bits of lower as v lsbs
                        // 2 lsbs of upper as v msbs
                        v = lo | (hi << 6);
                        break;
                    case 1:
                        // 4 msbs of lower as v lsbs
                        // 4 lsbs of upper as v msbs
                        v = (lo >> 2) | (hi << 4);
                        break;
                    case 2:
                        // 2 msbs of lower as v lsbs
                        // all 6 bits of upper as v msbs
                        v = (lo >> 4) | (hi << 2);
                        break;
                }

                out[dstBase + j] = v;
            }
        }
    } else {
        // unhandled
        FAISS_ASSERT(false);
    }

    return out;
}
/**
 * Scatters whole-word codes from vector-major order into the 32-vector
 * interleaved layout.
 *
 * Vectors are grouped into blocks of 32. Inside a block every dimension owns
 * wordsPerDimBlock consecutive words of type T, and a vector's word for that
 * dimension is written at its lane (index within the block).
 *
 * @param in          source words, read as in[vec * dims + dim]
 * @param out         interleaved destination; only the slots of the numVecs
 *                    vectors are written
 * @param numVecs     number of vectors to store
 * @param dims        number of codes per vector
 * @param bitsPerCode width of one code in bits (the callers in this file pass
 *                    8 * sizeof(T))
 */
template <typename T>
void
packInterleavedWord(const T* in,
                    T* out,
                    int numVecs,
                    int dims,
                    int bitsPerCode) {
    // Offsets are kept in size_t so that large inputs cannot overflow the
    // 32-bit int arithmetic that a plain `block * wordsPerBlock` would use.
    const size_t wordsPerDimBlock = 32 * bitsPerCode / (8 * sizeof(T));
    const size_t wordsPerBlock = wordsPerDimBlock * dims;
    const int numBlocks = utils::divUp(numVecs, 32);

    // We're guaranteed that all other slots not filled by the vectors present
    // are initialized to zero (from the vector constructor in packInterleaved)
#pragma omp parallel for
    for (int i = 0; i < numVecs; ++i) {
        const int block = i / 32;
        FAISS_ASSERT(block < numBlocks);

        const int lane = i % 32;
        const size_t dstBase = static_cast<size_t>(block) * wordsPerBlock + lane;
        const size_t srcBase = static_cast<size_t>(i) * dims;

        for (int j = 0; j < dims; ++j) {
            out[dstBase + j * wordsPerDimBlock] = in[srcBase + j];
        }
    }
}
/**
 * Converts byte-padded, vector-major codes into the 32-vector interleaved
 * layout.
 *
 * The output is a sequence of divUp(numVecs, 32) blocks. Inside a block each
 * dimension owns 32 * bitsPerCode / 8 bytes holding that dimension's code for
 * the 32 vectors of the block, bit-packed in lane order. Slots belonging to
 * vectors past numVecs are zero.
 *
 * @param data        one code per divUp(bitsPerCode, 8) bytes, in [vec][dim]
 *                    order; must hold exactly
 *                    numVecs * dims * divUp(bitsPerCode, 8) bytes
 * @param numVecs     number of vectors
 * @param dims        number of codes per vector
 * @param bitsPerCode width of one code; 4, 6, 8, 16 and 32 are handled
 * @return the interleaved bytes. Any other width trips FAISS_ASSERT.
 */
std::vector<uint8_t>
packInterleaved(std::vector<uint8_t> data,
                int numVecs,
                int dims,
                int bitsPerCode) {
    // Sizes and offsets are computed in size_t; the int products used before
    // overflow once the buffers pass 2 GiB.
    const size_t vecCount = static_cast<size_t>(numVecs);
    const int bytesPerDimBlock = 32 * bitsPerCode / 8;
    const size_t bytesPerBlock = static_cast<size_t>(bytesPerDimBlock) * dims;
    const int numBlocks = utils::divUp(numVecs, 32);
    const size_t totalSize = bytesPerBlock * numBlocks;

    // bit codes padded to whole bytes
    FAISS_ASSERT(data.size() == vecCount * dims * utils::divUp(bitsPerCode, 8));

    // packs based on blocks
    std::vector<uint8_t> out(totalSize, 0);

    if (bitsPerCode == 8) {
        packInterleavedWord<uint8_t>(
                data.data(), out.data(), numVecs, dims, bitsPerCode);
    } else if (bitsPerCode == 16) {
        packInterleavedWord<uint16_t>(
                reinterpret_cast<uint16_t*>(data.data()),
                reinterpret_cast<uint16_t*>(out.data()),
                numVecs, dims, bitsPerCode);
    } else if (bitsPerCode == 32) {
        packInterleavedWord<uint32_t>(
                reinterpret_cast<uint32_t*>(data.data()),
                reinterpret_cast<uint32_t*>(out.data()),
                numVecs, dims, bitsPerCode);
    } else if (bitsPerCode == 4) {
#pragma omp parallel for
        for (int i = 0; i < numBlocks; ++i) {
            const size_t firstVec = static_cast<size_t>(i) * 32;

            for (int j = 0; j < dims; ++j) {
                const size_t dstBase = static_cast<size_t>(i) * bytesPerBlock +
                        static_cast<size_t>(j) * bytesPerDimBlock;

                for (int k = 0; k < bytesPerDimBlock; ++k) {
                    // Two lanes share a byte: even lane in the low nibble,
                    // odd lane in the high nibble
                    const size_t loVec = firstVec + k * 2;
                    const size_t hiVec = loVec + 1;

                    // Lanes past the last real vector contribute zeros
                    const uint8_t lo =
                            loVec < vecCount ? data[loVec * dims + j] : 0;
                    const uint8_t hi =
                            hiVec < vecCount ? data[hiVec * dims + j] : 0;

                    out[dstBase + k] = (hi << 4) | (lo & 0xf);
                }
            }
        }
    } else if (bitsPerCode == 6) {
#pragma omp parallel for
        for (int i = 0; i < numBlocks; ++i) {
            const size_t firstVec = static_cast<size_t>(i) * 32;

            for (int j = 0; j < dims; ++j) {
                const size_t dstBase = static_cast<size_t>(i) * bytesPerBlock +
                        static_cast<size_t>(j) * bytesPerDimBlock;

                for (int k = 0; k < bytesPerDimBlock; ++k) {
                    // What input vectors we are pulling from
                    const size_t loVec = firstVec + (k * 8) / 6;
                    const size_t hiVec = loVec + 1;

                    // Lanes past the last real vector contribute zeros.
                    // Keep only the 6 payload bits of the lower code so that
                    // stray high bits cannot spill into the neighbouring
                    // code once it is shifted down. The upper code needs no
                    // mask: the left shift plus the uint8_t store discards
                    // its excess bits.
                    const uint8_t lo =
                            (loVec < vecCount ? data[loVec * dims + j] : 0) &
                            0x3f;
                    const uint8_t hi =
                            hiVec < vecCount ? data[hiVec * dims + j] : 0;

                    uint8_t v = 0;

                    // lsb     ...    msb
                    // 0: 0 0 0 0 0 0 1 1
                    // 1: 1 1 1 1 2 2 2 2
                    // 2: 2 2 3 3 3 3 3 3
                    switch (k % 3) {
                        case 0:
                            // all 6 bits of lower as v lsbs
                            // 2 lsbs of upper as v msbs
                            v = lo | (hi << 6);
                            break;
                        case 1:
                            // 4 msbs of lower as v lsbs
                            // 4 lsbs of upper as v msbs
                            v = (lo >> 2) | (hi << 4);
                            break;
                        case 2:
                            // 2 msbs of lower as v lsbs
                            // all 6 bits of upper as v msbs
                            v = (lo >> 4) | (hi << 2);
                            break;
                    }

                    out[dstBase + k] = v;
                }
            }
        }
    } else {
        // unimplemented
        FAISS_ASSERT(false);
    }

    return out;
}
} } // namespace