CUDA 11 fixes + PQ training on the GPU

Summary:
This diff exposes the `ProductQuantizer` object `pq` on `GpuIndexIVFPQ` so the user can manipulate it from Python, just as the CPU `IndexIVFPQ` allows.

If no clustering index has been assigned to `pq.assign_index`, we create a temporary `GpuIndexFlatL2` so that the PQ training also runs on the GPU.
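
For illustration, a rough usage sketch of the exposed `pq` member (not code from this diff; the dimensions, list count, and training data below are assumed):

```cpp
#include <faiss/gpu/StandardGpuResources.h>
#include <faiss/gpu/GpuIndexIVFPQ.h>

void trainExample(const float* trainVecs, faiss::Index::idx_t nTrain) {
    faiss::gpu::StandardGpuResources res;

    int d = 64;        // vector dimension (assumed)
    int nlist = 256;   // number of IVF lists (assumed)
    int M = 8;         // PQ sub-quantizers
    int nbits = 8;     // bits per PQ code

    faiss::gpu::GpuIndexIVFPQ index(&res, d, nlist, M, nbits, faiss::METRIC_L2);

    // `pq` is now a public member, as on the CPU IndexIVFPQ
    index.pq.verbose = true;      // log PQ k-means progress
    index.pq.cp.niter = 50;       // tune the sub-quantizer clustering

    // If index.pq.assign_index is left null (the default), training builds a
    // temporary GpuIndexFlatL2 internally so the PQ k-means assignment also
    // runs on the GPU.
    index.train(nTrain, trainVecs);
}
```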

Also raises the error thresholds slightly in some tests, as the previous thresholds appear to be exceeded on a V100 GPU.

Also fixes an issue with AddException under CUDA 11 and/or on V100 GPUs, where a `cudaMalloc` failure now appears to set sticky error state that is subsequently returned by `cudaGetLastError`. We now clear this state before continuing.
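
As a standalone illustration of the behavior being worked around (a minimal sketch, not Faiss code):

```cpp
#include <cuda_runtime.h>
#include <cstdio>

// Attempt a device allocation; on failure, consume the sticky error state so
// that a later cudaGetLastError (e.g. after the next kernel launch) does not
// report this already-handled cudaMalloc failure.
void* tryDeviceAlloc(size_t bytes) {
    void* p = nullptr;
    cudaError_t err = cudaMalloc(&p, bytes);

    if (err != cudaSuccess) {
        cudaGetLastError();   // clear the error so it is not seen downstream

        std::fprintf(stderr, "cudaMalloc of %zu bytes failed: %s\n",
                     bytes, cudaGetErrorString(err));
        return nullptr;
    }

    return p;
}
```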

Fixes an issue (possibly a cuBLAS bug; we are following up with Nvidia):

`cublasSgemmEx` in libcublas.so.11.1.0.229 returns CUBLAS_STATUS_NOT_SUPPORTED for the call below, which works fine under CUDA 9.2 (V100 GPU):

cublasSgemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
64, 8, 64,
&alpha,
A, CUDA_R_16F, 64,
B, CUDA_R_16F, 64,
&beta,
C, CUDA_R_32F, 64);

Using `cublasGemmEx` with CUBLAS_COMPUTE_32F and CUBLAS_GEMM_DEFAULT also fails, but `cublasGemmEx` with CUBLAS_COMPUTE_32F_PEDANTIC succeeds. For now, we use the PEDANTIC compute type on CUDA 11 when the arguments are f16.
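
In isolation, the workaround amounts to something like the following (a sketch using the same shapes as the failing call above; CUDA 11+ only):

```cpp
#include <cublas_v2.h>
#include <cuda_fp16.h>

// f16 inputs, f32 output, using the f32 PEDANTIC compute type that succeeds
// where cublasSgemmEx and CUBLAS_COMPUTE_32F return CUBLAS_STATUS_NOT_SUPPORTED.
cublasStatus_t gemmF16Pedantic(cublasHandle_t handle,
                               const __half* A, const __half* B, float* C) {
    const float alpha = 1.0f;
    const float beta = 0.0f;

    return cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                        64, 8, 64,
                        &alpha,
                        A, CUDA_R_16F, 64,
                        B, CUDA_R_16F, 64,
                        &beta,
                        C, CUDA_R_32F, 64,
                        CUBLAS_COMPUTE_32F_PEDANTIC,
                        CUBLAS_GEMM_DEFAULT);
}
```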

Reviewed By: mdouze

Differential Revision: D26331887

fbshipit-source-id: c65448c4c79b58dd49b0220b393056e431ef53c0
Jeff Johnson 2021-02-10 15:21:33 -08:00 committed by Facebook GitHub Bot
parent 08a0ce72a2
commit 43ce2c93a4
5 changed files with 94 additions and 29 deletions

View File

@@ -9,7 +9,6 @@
#include <faiss/gpu/GpuIndexIVFPQ.h>
#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFPQ.h>
#include <faiss/impl/ProductQuantizer.h>
#include <faiss/utils/utils.h>
#include <faiss/gpu/GpuIndexFlat.h>
#include <faiss/gpu/GpuResources.h>
@@ -30,6 +29,7 @@ GpuIndexIVFPQ::GpuIndexIVFPQ(GpuResourcesProvider* provider,
index->metric_arg,
index->nlist,
config),
pq(index->pq),
ivfpqConfig_(config),
usePrecomputedTables_(config.usePrecomputedTables),
subQuantizers_(0),
@@ -51,6 +51,7 @@ GpuIndexIVFPQ::GpuIndexIVFPQ(GpuResourcesProvider* provider,
0,
nlist,
config),
pq(dims, subQuantizers, bitsPerCode),
ivfpqConfig_(config),
usePrecomputedTables_(config.usePrecomputedTables),
subQuantizers_(subQuantizers),
@@ -74,6 +75,7 @@ GpuIndexIVFPQ::copyFrom(const faiss::IndexIVFPQ* index) {
// Clear out our old data
index_.reset();
pq = index->pq;
subQuantizers_ = index->pq.M;
bitsPerCode_ = index->pq.nbits;
@@ -228,10 +230,6 @@
void
GpuIndexIVFPQ::trainResidualQuantizer_(Index::idx_t n, const float* x) {
// Just use the CPU product quantizer to determine sub-centroids
faiss::ProductQuantizer pq(this->d, subQuantizers_, bitsPerCode_);
pq.verbose = verbose;
// Code largely copied from faiss::IndexIVFPQ
auto x_in = x;
@@ -260,7 +258,26 @@ GpuIndexIVFPQ::trainResidualQuantizer_(Index::idx_t n, const float* x) {
subQuantizers_, getCentroidsPerSubQuantizer(), n, this->d);
}
pq.train(n, residuals.data());
// For PQ training purposes, accelerate it by using a GPU clustering index if
// a clustering index has not already been assigned
if (!pq.assign_index) {
try {
GpuIndexFlatConfig config;
config.device = ivfpqConfig_.device;
GpuIndexFlatL2 pqIndex(resources_, pq.dsub, config);
pq.assign_index = &pqIndex;
pq.train(n, residuals.data());
} catch (...) {
pq.assign_index = nullptr;
throw;
}
pq.assign_index = nullptr;
} else {
// use the currently assigned clustering index
pq.train(n, residuals.data());
}
index_.reset(new IVFPQ(resources_.get(),
metric_type,

View File

@@ -9,6 +9,7 @@
#pragma once
#include <faiss/gpu/GpuIndexIVF.h>
#include <faiss/impl/ProductQuantizer.h>
#include <memory>
#include <vector>
@@ -127,6 +128,11 @@ class GpuIndexIVFPQ : public GpuIndexIVF {
/// debugging purposes.
std::vector<Index::idx_t> getListIndices(int listId) const override;
public:
/// Like the CPU version, we expose a publicly visible ProductQuantizer for
/// manipulation
ProductQuantizer pq;
protected:
/// Called from GpuIndex for add/add_with_ids
void addImpl_(int n,

View File

@@ -410,10 +410,6 @@ StandardGpuResourcesImpl::allocMemory(const AllocRequest& req) {
void* p = nullptr;
if (allocLogging_) {
std::cout << "StandardGpuResources: alloc " << adjReq.toString() << "\n";
}
if (adjReq.space == MemorySpace::Temporary) {
// If we don't have enough space in our temporary memory manager, we need
// to allocate this request separately
@@ -425,6 +421,11 @@
newReq.space = MemorySpace::Device;
newReq.type = AllocType::TemporaryMemoryOverflow;
if (allocLogging_) {
std::cout << "StandardGpuResources: alloc fail " << adjReq.toString()
<< " (no temp space); retrying as MemorySpace::Device\n";
}
return allocMemory(newReq);
}
@@ -436,35 +437,53 @@
// Throw if we fail to allocate
if (err != cudaSuccess) {
auto& map = allocs_[req.device];
// FIXME: as of CUDA 11, a memory allocation error appears to be presented
// via cudaGetLastError as well, and needs to be cleared. Just call the
// function to clear it
cudaGetLastError();
std::stringstream ss;
ss << "Failed to cudaMalloc " << adjReq.size << " bytes "
<< "on device " << adjReq.device << " (error "
<< (int) err << " " << cudaGetErrorString(err)
<< "\nOutstanding allocations:\n" << allocsToString(map);
ss << "StandardGpuResources: alloc fail " << adjReq.toString()
<< " (cudaMalloc error "
<< cudaGetErrorString(err) << " [" << (int) err << "])\n";
auto str = ss.str();
if (allocLogging_) {
std::cout << str;
}
FAISS_THROW_IF_NOT_FMT(err == cudaSuccess, "%s", str.c_str());
}
} else if (adjReq.space == MemorySpace::Unified) {
auto err = cudaMallocManaged(&p, adjReq.size);
if (err != cudaSuccess) {
auto& map = allocs_[req.device];
// FIXME: as of CUDA 11, a memory allocation error appears to be presented
// via cudaGetLastError as well, and needs to be cleared. Just call the
// function to clear it
cudaGetLastError();
std::stringstream ss;
ss << "Failed to cudaMallocManaged " << adjReq.size << " bytes "
<< "(error " << (int) err << " " << cudaGetErrorString(err)
<< "\nOutstanding allocations:\n" << allocsToString(map);
ss << "StandardGpuResources: alloc fail " << adjReq.toString()
<< " failed (cudaMallocManaged error "
<< cudaGetErrorString(err) << " [" << (int) err << "])\n";
auto str = ss.str();
if (allocLogging_) {
std::cout << str;
}
FAISS_THROW_IF_NOT_FMT(err == cudaSuccess, "%s", str.c_str());
}
} else {
FAISS_ASSERT_FMT(false, "unknown MemorySpace %d", (int) adjReq.space);
}
if (allocLogging_) {
std::cout << "StandardGpuResources: alloc ok " << adjReq.toString()
<< " ptr 0x" << p << "\n";
}
allocs_[adjReq.device][p] = adjReq;
return p;

View File

@@ -203,15 +203,15 @@ class TestInterleavedIVFPQLayout(unittest.TestCase):
# Try without precomputed codes
d_g, i_g = idx_gpu.search(xq, 10)
d_c, i_c = idx_cpu.search(xq, 10)
self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 10)
self.assertTrue(np.allclose(d_g, d_c))
self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 25)
self.assertTrue(np.allclose(d_g, d_c, rtol=5e-5, atol=3e-1))
# Try with precomputed codes (different kernel)
idx_gpu.setPrecomputedCodes(True)
d_g, i_g = idx_gpu.search(xq, 10)
d_c, i_c = idx_cpu.search(xq, 10)
self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 10)
self.assertTrue(np.allclose(d_g, d_c))
self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 25)
self.assertTrue(np.allclose(d_g, d_c, rtol=5e-5, atol=3e-1))
def test_copy_to_cpu(self):
res = faiss.StandardGpuResources()
@@ -247,14 +247,14 @@ class TestInterleavedIVFPQLayout(unittest.TestCase):
# Try without precomputed codes
d_g, i_g = idx_gpu.search(xq, 10)
d_c, i_c = idx_cpu.search(xq, 10)
self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 10)
self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 20)
self.assertTrue(np.allclose(d_g, d_c))
# Try with precomputed codes (different kernel)
idx_gpu.setPrecomputedCodes(True)
d_g, i_g = idx_gpu.search(xq, 10)
d_c, i_c = idx_cpu.search(xq, 10)
self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 10)
self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 20)
self.assertTrue(np.allclose(d_g, d_c))
def test_copy_to_gpu(self):
@@ -291,14 +291,14 @@ class TestInterleavedIVFPQLayout(unittest.TestCase):
# Try without precomputed codes
d_g, i_g = idx_gpu.search(xq, 10)
d_c, i_c = idx_cpu.search(xq, 10)
self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 10)
self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 20)
self.assertTrue(np.allclose(d_g, d_c))
# Try with precomputed codes (different kernel)
idx_gpu.setPrecomputedCodes(True)
d_g, i_g = idx_gpu.search(xq, 10)
d_c, i_c = idx_cpu.search(xq, 10)
self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 10)
self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 20)
self.assertTrue(np.allclose(d_g, d_c))

View File

@@ -48,6 +48,24 @@ rawGemm(cublasHandle_t handle,
auto cAT = GetCudaType<AT>::Type;
auto cBT = GetCudaType<BT>::Type;
// FIXME: some weird CUDA 11 bug? where cublasSgemmEx on
// f16 (8, 64) x f16 (64, 64)' = f32 (8, 64) returns "not supported".
// cublasGemmEx using CUBLAS_COMPUTE_32F also fails, but
// CUBLAS_COMPUTE_32F_PEDANTIC does not fail (as seen on a V100).
//
// Only use the PEDANTIC implementation if the input matrices are f16
// and we are on CUDA 11+
#if CUDA_VERSION >= 11000
if (cAT == CUDA_R_16F || cBT == CUDA_R_16F) {
return cublasGemmEx(handle, transa, transb, m, n, k,
&fAlpha, A, cAT, lda,
B, cBT, ldb,
&fBeta,
C, CUDA_R_32F, ldc,
CUBLAS_COMPUTE_32F_PEDANTIC, CUBLAS_GEMM_DEFAULT);
}
#endif
// Always accumulate in f32
return cublasSgemmEx(handle, transa, transb, m, n, k,
&fAlpha, A, cAT, lda,
@@ -155,11 +173,16 @@ runMatrixMult(Tensor<float, 2, true>& c, bool transC,
FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS,
"cublas failed (%d): "
"(%d, %d)%s x (%d, %d)%s = (%d, %d)%s",
"(%d, %d)%s x (%d, %d)%s = (%d, %d)%s "
"gemm params m %d n %d k %d trA %s trB %s lda %d ldb %d ldc %d",
(int) err,
a.getSize(0), a.getSize(1), transA ? "'" : "",
b.getSize(0), b.getSize(1), transB ? "'" : "",
c.getSize(0), c.getSize(1), transC ? "'" : "");
c.getSize(0), c.getSize(1), transC ? "'" : "",
m, n, k,
gemmTrA == CUBLAS_OP_T ? "T" : "N",
gemmTrB == CUBLAS_OP_T ? "T" : "N",
lda, ldb, ldc);
CUDA_TEST_ERROR();
}