411 lines
12 KiB
C++
411 lines
12 KiB
C++
/**
|
|
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
*
|
|
* This source code is licensed under the MIT license found in the
|
|
* LICENSE file in the root directory of this source tree.
|
|
*/
|
|
|
|
#include <faiss/gpu/impl/InterleavedCodes.h>
|
|
#include <faiss/impl/FaissAssert.h>
|
|
#include <faiss/gpu/utils/StaticUtils.h>
|
|
|
|
namespace faiss { namespace gpu {
|
|
|
|
std::vector<uint8_t>
|
|
unpackNonInterleaved(std::vector<uint8_t> data,
|
|
int numVecs,
|
|
int dims,
|
|
int bitsPerCode) {
|
|
int srcVecSize = utils::divUp(dims * bitsPerCode, 8);
|
|
FAISS_ASSERT(data.size() == numVecs * srcVecSize);
|
|
|
|
if (bitsPerCode == 8 ||
|
|
bitsPerCode == 16 ||
|
|
bitsPerCode == 32) {
|
|
// nothing to do
|
|
return data;
|
|
}
|
|
|
|
// bit codes padded to whole bytes
|
|
std::vector<uint8_t> out(numVecs * dims * utils::divUp(bitsPerCode, 8));
|
|
|
|
if (bitsPerCode == 6) {
|
|
#pragma omp parallel for
|
|
for (int i = 0; i < numVecs; ++i) {
|
|
for (int j = 0; j < dims; ++j) {
|
|
int lo = i * srcVecSize + (j * 6) / 8;
|
|
int hi = lo + 1;
|
|
|
|
FAISS_ASSERT(lo < data.size());
|
|
FAISS_ASSERT(hi <= data.size());
|
|
|
|
auto vLower = data[lo];
|
|
auto vUpper = hi < data.size() ? data[hi] : 0;
|
|
|
|
uint8_t v = 0;
|
|
|
|
switch (j % 4) {
|
|
case 0:
|
|
// 6 lsbs of lower
|
|
v = vLower & 0x3f;
|
|
break;
|
|
case 1:
|
|
// 2 msbs of lower as v lsbs
|
|
// 4 lsbs of upper as v msbs
|
|
v = (vLower >> 6) | ((vUpper & 0xf) << 2);
|
|
break;
|
|
case 2:
|
|
// 4 msbs of lower as v lsbs
|
|
// 2 lsbs of upper as v msbs
|
|
v = (vLower >> 4) | ((vUpper & 0x3) << 4);
|
|
break;
|
|
case 3:
|
|
// 6 msbs of lower
|
|
v = (vLower >> 2);
|
|
break;
|
|
}
|
|
|
|
out[i * dims + j] = v;
|
|
}
|
|
}
|
|
} else if (bitsPerCode == 4) {
|
|
#pragma omp parallel for
|
|
for (int i = 0; i < numVecs; ++i) {
|
|
for (int j = 0; j < dims; ++j) {
|
|
int srcIdx = i * srcVecSize + (j / 2);
|
|
FAISS_ASSERT(srcIdx < data.size());
|
|
|
|
uint8_t v = data[srcIdx];
|
|
v = (j % 2 == 0) ? v & 0xf : v >> 4;
|
|
|
|
out[i * dims + j] = v;
|
|
}
|
|
}
|
|
} else {
|
|
// unhandled
|
|
FAISS_ASSERT(false);
|
|
}
|
|
|
|
return out;
|
|
}
|
|
|
|
/// Gathers whole-word codes (one T per code) out of the interleaved block
/// layout into a dense row-major (numVecs x dims) array.
///
/// Interleaved layout (as indexed below): vectors are grouped into blocks of
/// 32; within a block, dimension j of all 32 vectors is stored as a run of
/// 32 * bitsPerCode / (8 * sizeof(T)) words, one per lane, and the runs for
/// successive dimensions follow each other. Lanes past numVecs in the last
/// block are never read.
///
/// @param in          interleaved source words
/// @param out         destination; must hold numVecs * dims words
/// @param numVecs     number of encoded vectors
/// @param dims        number of codes per vector
/// @param bitsPerCode code width; callers pass 8 * sizeof(T)
template <typename T>
void
unpackInterleavedWord(const T* in,
                      T* out,
                      int numVecs,
                      int dims,
                      int bitsPerCode) {
    // Offsets are size_t: block * wordsPerBlock and i * dims scale with
    // numVecs * dims and can exceed INT_MAX for large inputs.
    size_t wordsPerDimBlock = 32 * bitsPerCode / (8 * sizeof(T));
    size_t wordsPerBlock = wordsPerDimBlock * dims;

    // (The previous `block < numBlocks` check was always true: i < numVecs
    // implies i / 32 < ceil(numVecs / 32).)
#pragma omp parallel for
    for (int i = 0; i < numVecs; ++i) {
        size_t block = i / 32;
        size_t lane = i % 32;

        const T* src = in + block * wordsPerBlock + lane;
        T* dst = out + (size_t) i * dims;

        for (int j = 0; j < dims; ++j) {
            dst[j] = src[j * wordsPerDimBlock];
        }
    }
}
|
|
|
|
std::vector<uint8_t>
|
|
unpackInterleaved(std::vector<uint8_t> data,
|
|
int numVecs,
|
|
int dims,
|
|
int bitsPerCode) {
|
|
int bytesPerDimBlock = 32 * bitsPerCode / 8;
|
|
int bytesPerBlock = bytesPerDimBlock * dims;
|
|
int numBlocks = utils::divUp(numVecs, 32);
|
|
size_t totalSize = (size_t) bytesPerBlock * numBlocks;
|
|
FAISS_ASSERT(data.size() == totalSize);
|
|
|
|
// bit codes padded to whole bytes
|
|
std::vector<uint8_t> out(numVecs * dims * utils::divUp(bitsPerCode, 8));
|
|
|
|
if (bitsPerCode == 8) {
|
|
unpackInterleavedWord<uint8_t>(data.data(), out.data(),
|
|
numVecs, dims, bitsPerCode);
|
|
} else if (bitsPerCode == 16) {
|
|
unpackInterleavedWord<uint16_t>((uint16_t*) data.data(),
|
|
(uint16_t*) out.data(),
|
|
numVecs, dims, bitsPerCode);
|
|
} else if (bitsPerCode == 32) {
|
|
unpackInterleavedWord<uint32_t>((uint32_t*) data.data(),
|
|
(uint32_t*) out.data(),
|
|
numVecs, dims, bitsPerCode);
|
|
} else if (bitsPerCode == 4) {
|
|
#pragma omp parallel for
|
|
for (int i = 0; i < numVecs; ++i) {
|
|
int block = i / 32;
|
|
int lane = i % 32;
|
|
|
|
int word = lane / 2;
|
|
int subWord = lane % 2;
|
|
|
|
for (int j = 0; j < dims; ++j) {
|
|
auto v =
|
|
data[block * bytesPerBlock + j * bytesPerDimBlock + word];
|
|
|
|
v = (subWord == 0) ? v & 0xf : v >> 4;
|
|
out[i * dims + j] = v;
|
|
}
|
|
}
|
|
} else if (bitsPerCode == 6) {
|
|
#pragma omp parallel for
|
|
for (int i = 0; i < numVecs; ++i) {
|
|
int block = i / 32;
|
|
int blockVector = i % 32;
|
|
|
|
for (int j = 0; j < dims; ++j) {
|
|
uint8_t* dimBlock =
|
|
&data[block * bytesPerBlock + j * bytesPerDimBlock];
|
|
|
|
int lo = (blockVector * 6) / 8;
|
|
int hi = lo + 1;
|
|
|
|
FAISS_ASSERT(lo < bytesPerDimBlock);
|
|
FAISS_ASSERT(hi <= bytesPerDimBlock);
|
|
|
|
auto vLower = dimBlock[lo];
|
|
auto vUpper = hi < bytesPerDimBlock ? dimBlock[hi] : 0;
|
|
|
|
uint8_t v = 0;
|
|
|
|
switch (blockVector % 4) {
|
|
case 0:
|
|
// 6 lsbs of lower
|
|
v = vLower & 0x3f;
|
|
break;
|
|
case 1:
|
|
// 2 msbs of lower as v lsbs
|
|
// 4 lsbs of upper as v msbs
|
|
v = (vLower >> 6) | ((vUpper & 0xf) << 2);
|
|
break;
|
|
case 2:
|
|
// 4 msbs of lower as v lsbs
|
|
// 2 lsbs of upper as v msbs
|
|
v = (vLower >> 4) | ((vUpper & 0x3) << 4);
|
|
break;
|
|
case 3:
|
|
// 6 msbs of lower
|
|
v = (vLower >> 2);
|
|
break;
|
|
}
|
|
|
|
out[i * dims + j] = v;
|
|
}
|
|
}
|
|
} else {
|
|
// unimplemented
|
|
FAISS_ASSERT(false);
|
|
}
|
|
|
|
return out;
|
|
}
|
|
|
|
std::vector<uint8_t>
|
|
packNonInterleaved(std::vector<uint8_t> data,
|
|
int numVecs,
|
|
int dims,
|
|
int bitsPerCode) {
|
|
// bit codes padded to whole bytes
|
|
FAISS_ASSERT(data.size() == numVecs * dims * utils::divUp(bitsPerCode, 8));
|
|
|
|
if (bitsPerCode == 8 ||
|
|
bitsPerCode == 16 ||
|
|
bitsPerCode == 32) {
|
|
// nothing to do, whole words are already where they need to be
|
|
return data;
|
|
}
|
|
|
|
// bits packed into a whole number of bytes
|
|
int bytesPerVec = utils::divUp(dims * bitsPerCode, 8);
|
|
|
|
std::vector<uint8_t> out(numVecs * bytesPerVec);
|
|
|
|
if (bitsPerCode == 4) {
|
|
#pragma omp parallel for
|
|
for (int i = 0; i < numVecs; ++i) {
|
|
for (int j = 0; j < bytesPerVec; ++j) {
|
|
int dimLo = j * 2;
|
|
int dimHi = dimLo + 1;
|
|
FAISS_ASSERT(dimLo < dims);
|
|
FAISS_ASSERT(dimHi <= dims);
|
|
|
|
uint8_t lo = data[i * dims + dimLo];
|
|
uint8_t hi = dimHi < dims ? data[i * dims + dimHi] : 0;
|
|
|
|
out[i * bytesPerVec + j] = (hi << 4) | (lo & 0xf);
|
|
}
|
|
}
|
|
} else if (bitsPerCode == 6) {
|
|
#pragma omp parallel for
|
|
for (int i = 0; i < numVecs; ++i) {
|
|
for (int j = 0; j < bytesPerVec; ++j) {
|
|
int dimLo = (j * 8) / 6;
|
|
int dimHi = dimLo + 1;
|
|
FAISS_ASSERT(dimLo < dims);
|
|
FAISS_ASSERT(dimHi <= dims);
|
|
|
|
uint8_t lo = data[i * dims + dimLo];
|
|
uint8_t hi = dimHi < dims ? data[i * dims + dimHi] : 0;
|
|
|
|
uint8_t v = 0;
|
|
|
|
// lsb ... msb
|
|
// 0: 0 0 0 0 0 0 1 1
|
|
// 1: 1 1 1 1 2 2 2 2
|
|
// 2: 2 2 3 3 3 3 3 3
|
|
switch (j % 3) {
|
|
case 0:
|
|
// 6 msbs of lower as vOut lsbs
|
|
// 2 lsbs of upper as vOut msbs
|
|
v = (lo & 0x3f) | (hi << 6);
|
|
break;
|
|
case 1:
|
|
// 4 msbs of lower as vOut lsbs
|
|
// 4 lsbs of upper as vOut msbs
|
|
v = (lo >> 2) | (hi << 4);
|
|
break;
|
|
case 2:
|
|
// 2 msbs of lower as vOut lsbs
|
|
// 6 lsbs of upper as vOut msbs
|
|
v = (lo >> 4) | (hi << 2);
|
|
break;
|
|
}
|
|
|
|
out[i * bytesPerVec + j] = v;
|
|
}
|
|
}
|
|
} else {
|
|
// unhandled
|
|
FAISS_ASSERT(false);
|
|
}
|
|
|
|
return out;
|
|
}
|
|
|
|
/// Scatters a dense row-major (numVecs x dims) array of whole-word codes
/// (one T per code) into the interleaved block layout.
///
/// Interleaved layout (as indexed below): vectors are grouped into blocks of
/// 32; within a block, dimension j of all 32 vectors is stored as a run of
/// 32 * bitsPerCode / (8 * sizeof(T)) words, one per lane, and the runs for
/// successive dimensions follow each other.
///
/// Only slots belonging to existing vectors are written. Slots for lanes
/// past numVecs in the last block keep whatever `out` held, so the caller
/// must pre-initialize them (packInterleaved zero-fills its buffer).
///
/// @param in          dense source words, numVecs * dims of them
/// @param out         interleaved destination, ceil(numVecs / 32) blocks
/// @param numVecs     number of encoded vectors
/// @param dims        number of codes per vector
/// @param bitsPerCode code width; callers pass 8 * sizeof(T)
template <typename T>
void
packInterleavedWord(const T* in,
                    T* out,
                    int numVecs,
                    int dims,
                    int bitsPerCode) {
    // Offsets are size_t: block * wordsPerBlock and i * dims scale with
    // numVecs * dims and can exceed INT_MAX for large inputs.
    size_t wordsPerDimBlock = 32 * bitsPerCode / (8 * sizeof(T));
    size_t wordsPerBlock = wordsPerDimBlock * dims;

    // (The previous `block < numBlocks` check was always true: i < numVecs
    // implies i / 32 < ceil(numVecs / 32).)
#pragma omp parallel for
    for (int i = 0; i < numVecs; ++i) {
        size_t block = i / 32;
        size_t lane = i % 32;

        const T* src = in + (size_t) i * dims;
        T* dst = out + block * wordsPerBlock + lane;

        for (int j = 0; j < dims; ++j) {
            dst[j * wordsPerDimBlock] = src[j];
        }
    }
}
|
|
|
|
std::vector<uint8_t>
|
|
packInterleaved(std::vector<uint8_t> data,
|
|
int numVecs,
|
|
int dims,
|
|
int bitsPerCode) {
|
|
int bytesPerDimBlock = 32 * bitsPerCode / 8;
|
|
int bytesPerBlock = bytesPerDimBlock * dims;
|
|
int numBlocks = utils::divUp(numVecs, 32);
|
|
size_t totalSize = (size_t) bytesPerBlock * numBlocks;
|
|
|
|
// bit codes padded to whole bytes
|
|
FAISS_ASSERT(data.size() == numVecs * dims * utils::divUp(bitsPerCode, 8));
|
|
|
|
// packs based on blocks
|
|
std::vector<uint8_t> out(totalSize, 0);
|
|
|
|
if (bitsPerCode == 8) {
|
|
packInterleavedWord<uint8_t>(data.data(), out.data(),
|
|
numVecs, dims, bitsPerCode);
|
|
} else if (bitsPerCode == 16) {
|
|
packInterleavedWord<uint16_t>((uint16_t*) data.data(),
|
|
(uint16_t*) out.data(),
|
|
numVecs, dims, bitsPerCode);
|
|
} else if (bitsPerCode == 32) {
|
|
packInterleavedWord<uint32_t>((uint32_t*) data.data(),
|
|
(uint32_t*) out.data(),
|
|
numVecs, dims, bitsPerCode);
|
|
} else if (bitsPerCode == 4) {
|
|
#pragma omp parallel for
|
|
for (int i = 0; i < numBlocks; ++i) {
|
|
for (int j = 0; j < dims; ++j) {
|
|
for (int k = 0; k < bytesPerDimBlock; ++k) {
|
|
int loVec = i * 32 + k * 2;
|
|
int hiVec = loVec + 1;
|
|
|
|
uint8_t lo = loVec < numVecs ? data[loVec * dims + j] : 0;
|
|
uint8_t hi = hiVec < numVecs ? data[hiVec * dims + j] : 0;
|
|
|
|
out[i * bytesPerBlock + j * bytesPerDimBlock + k] =
|
|
(hi << 4) | (lo & 0xf);
|
|
}
|
|
}
|
|
}
|
|
} else if (bitsPerCode == 6) {
|
|
#pragma omp parallel for
|
|
for (int i = 0; i < numBlocks; ++i) {
|
|
for (int j = 0; j < dims; ++j) {
|
|
for (int k = 0; k < bytesPerDimBlock; ++k) {
|
|
// What input vectors we are pulling from
|
|
int loVec = i * 32 + (k * 8) / 6;
|
|
int hiVec = loVec + 1;
|
|
|
|
uint8_t lo = loVec < numVecs ? data[loVec * dims + j] : 0;
|
|
uint8_t hi = hiVec < numVecs ? data[hiVec * dims + j] : 0;
|
|
|
|
uint8_t v = 0;
|
|
|
|
// lsb ... msb
|
|
// 0: 0 0 0 0 0 0 1 1
|
|
// 1: 1 1 1 1 2 2 2 2
|
|
// 2: 2 2 3 3 3 3 3 3
|
|
switch (k % 3) {
|
|
case 0:
|
|
// 6 msbs of lower as vOut lsbs
|
|
// 2 lsbs of upper as vOut msbs
|
|
v = (lo & 0x3f) | (hi << 6);
|
|
break;
|
|
case 1:
|
|
// 4 msbs of lower as vOut lsbs
|
|
// 4 lsbs of upper as vOut msbs
|
|
v = (lo >> 2) | (hi << 4);
|
|
break;
|
|
case 2:
|
|
// 2 msbs of lower as vOut lsbs
|
|
// 6 lsbs of upper as vOut msbs
|
|
v = (lo >> 4) | (hi << 2);
|
|
break;
|
|
}
|
|
|
|
out[i * bytesPerBlock + j * bytesPerDimBlock + k] = v;
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
// unimplemented
|
|
FAISS_ASSERT(false);
|
|
}
|
|
|
|
return out;
|
|
}
|
|
|
|
} } // namespace
|