CUDA 11 fixes + PQ training on the GPU
Summary:
This diff exposes the ProductQuantizer `pq` object to the user for manipulation in Python, just as `IndexIVFPQ` does. If no clustering index object is provided in `pq`, we create a `GpuIndexFlatL2` so that the PQ training also runs on the GPU.

Also raises the error threshold a bit in some tests, as the previous thresholds were being exceeded on a V100 GPU.

Also fixes an issue with AddException on CUDA 11 and/or V100 GPUs, where a `cudaMalloc` failure now seems to set state that is returned by `cudaGetLastError`; we now clear that state before continuing.

Fixes an issue (possible cuBLAS bug, being followed up with Nvidia): cublasSgemmEx in libcublas.so.11.1.0.229 returns CUBLAS_STATUS_NOT_SUPPORTED for a call that works fine in CUDA 9.2 (V100 GPU):

    cublasSgemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, 64, 8, 64,
                  &alpha, A, CUDA_R_16F, 64, B, CUDA_R_16F, 64,
                  &beta, C, CUDA_R_32F, 64);

Using cublasGemmEx with CUBLAS_COMPUTE_32F and CUBLAS_GEMM_DEFAULT also fails, but cublasGemmEx with CUBLAS_COMPUTE_32F_PEDANTIC succeeds. For now, the PEDANTIC compute type is used on CUDA 11 with f16 arguments.

Reviewed By: mdouze

Differential Revision: D26331887

fbshipit-source-id: c65448c4c79b58dd49b0220b393056e431ef53c0
parent 08a0ce72a2
commit 43ce2c93a4
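As a usage note (not part of this diff): since `pq` is now a public member of `GpuIndexIVFPQ` (see the header change below), a caller can inspect it or supply its own clustering index before training; if `pq.assign_index` is left unset, training builds a temporary `GpuIndexFlatL2` internally as described in the summary. A minimal C++ sketch of that usage follows; the dimensions, nlist, and training data are illustrative only.

    #include <faiss/gpu/GpuIndexFlat.h>
    #include <faiss/gpu/GpuIndexIVFPQ.h>
    #include <faiss/gpu/StandardGpuResources.h>

    #include <vector>

    int main() {
        int d = 64;  // illustrative dimension
        faiss::gpu::StandardGpuResources res;

        faiss::gpu::GpuIndexIVFPQ index(&res, d, /*nlist=*/256,
                                        /*subQuantizers=*/8, /*bitsPerCode=*/8,
                                        faiss::METRIC_L2);

        // Optionally provide a clustering index for PQ training; if
        // pq.assign_index stays null, a GpuIndexFlatL2 is created internally.
        faiss::gpu::GpuIndexFlatL2 pqAssign(&res, index.pq.dsub);
        index.pq.assign_index = &pqAssign;

        // Arbitrary, illustrative training data
        std::vector<float> xt(20000 * size_t(d));
        for (size_t i = 0; i < xt.size(); ++i) {
            xt[i] = (i % 1000) * 0.001f;
        }
        index.train(20000, xt.data());

        index.pq.assign_index = nullptr;  // avoid a dangling pointer
        return 0;
    }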
@@ -9,7 +9,6 @@
#include <faiss/gpu/GpuIndexIVFPQ.h>
#include <faiss/IndexFlat.h>
#include <faiss/IndexIVFPQ.h>
#include <faiss/impl/ProductQuantizer.h>
#include <faiss/utils/utils.h>
#include <faiss/gpu/GpuIndexFlat.h>
#include <faiss/gpu/GpuResources.h>

@@ -30,6 +29,7 @@ GpuIndexIVFPQ::GpuIndexIVFPQ(GpuResourcesProvider* provider,
                index->metric_arg,
                index->nlist,
                config),
    pq(index->pq),
    ivfpqConfig_(config),
    usePrecomputedTables_(config.usePrecomputedTables),
    subQuantizers_(0),

@@ -51,6 +51,7 @@ GpuIndexIVFPQ::GpuIndexIVFPQ(GpuResourcesProvider* provider,
                0,
                nlist,
                config),
    pq(dims, subQuantizers, bitsPerCode),
    ivfpqConfig_(config),
    usePrecomputedTables_(config.usePrecomputedTables),
    subQuantizers_(subQuantizers),

@@ -74,6 +75,7 @@ GpuIndexIVFPQ::copyFrom(const faiss::IndexIVFPQ* index) {
  // Clear out our old data
  index_.reset();

  pq = index->pq;
  subQuantizers_ = index->pq.M;
  bitsPerCode_ = index->pq.nbits;

@@ -228,10 +230,6 @@ GpuIndexIVFPQ::reset() {

void
GpuIndexIVFPQ::trainResidualQuantizer_(Index::idx_t n, const float* x) {
  // Just use the CPU product quantizer to determine sub-centroids
  faiss::ProductQuantizer pq(this->d, subQuantizers_, bitsPerCode_);
  pq.verbose = verbose;

  // Code largely copied from faiss::IndexIVFPQ
  auto x_in = x;

@@ -260,7 +258,26 @@ GpuIndexIVFPQ::trainResidualQuantizer_(Index::idx_t n, const float* x) {
           subQuantizers_, getCentroidsPerSubQuantizer(), n, this->d);
  }

  pq.train(n, residuals.data());
  // For PQ training purposes, accelerate it by using a GPU clustering index if
  // a clustering index has not already been assigned
  if (!pq.assign_index) {
    try {
      GpuIndexFlatConfig config;
      config.device = ivfpqConfig_.device;
      GpuIndexFlatL2 pqIndex(resources_, pq.dsub, config);

      pq.assign_index = &pqIndex;
      pq.train(n, residuals.data());
    } catch (...) {
      pq.assign_index = nullptr;
      throw;
    }

    pq.assign_index = nullptr;
  } else {
    // use the currently assigned clustering index
    pq.train(n, residuals.data());
  }

  index_.reset(new IVFPQ(resources_.get(),
                         metric_type,
@@ -9,6 +9,7 @@
#pragma once

#include <faiss/gpu/GpuIndexIVF.h>
#include <faiss/impl/ProductQuantizer.h>
#include <memory>
#include <vector>

@@ -127,6 +128,11 @@ class GpuIndexIVFPQ : public GpuIndexIVF {
  /// debugging purposes.
  std::vector<Index::idx_t> getListIndices(int listId) const override;

 public:
  /// Like the CPU version, we expose a publically-visible ProductQuantizer for
  /// manipulation
  ProductQuantizer pq;

 protected:
  /// Called from GpuIndex for add/add_with_ids
  void addImpl_(int n,
@@ -410,10 +410,6 @@ StandardGpuResourcesImpl::allocMemory(const AllocRequest& req) {

  void* p = nullptr;

  if (allocLogging_) {
    std::cout << "StandardGpuResources: alloc " << adjReq.toString() << "\n";
  }

  if (adjReq.space == MemorySpace::Temporary) {
    // If we don't have enough space in our temporary memory manager, we need
    // to allocate this request separately

@@ -425,6 +421,11 @@ StandardGpuResourcesImpl::allocMemory(const AllocRequest& req) {
    newReq.space = MemorySpace::Device;
    newReq.type = AllocType::TemporaryMemoryOverflow;

    if (allocLogging_) {
      std::cout << "StandardGpuResources: alloc fail " << adjReq.toString()
                << " (no temp space); retrying as MemorySpace::Device\n";
    }

    return allocMemory(newReq);
  }

@@ -436,35 +437,53 @@ StandardGpuResourcesImpl::allocMemory(const AllocRequest& req) {

    // Throw if we fail to allocate
    if (err != cudaSuccess) {
      auto& map = allocs_[req.device];
      // FIXME: as of CUDA 11, a memory allocation error appears to be presented
      // via cudaGetLastError as well, and needs to be cleared. Just call the
      // function to clear it
      cudaGetLastError();

      std::stringstream ss;
      ss << "Failed to cudaMalloc " << adjReq.size << " bytes "
         << "on device " << adjReq.device << " (error "
         << (int) err << " " << cudaGetErrorString(err)
         << "\nOutstanding allocations:\n" << allocsToString(map);
      ss << "StandardGpuResources: alloc fail " << adjReq.toString()
         << " (cudaMalloc error "
         << cudaGetErrorString(err) << " [" << (int) err << "])\n";
      auto str = ss.str();

      if (allocLogging_) {
        std::cout << str;
      }

      FAISS_THROW_IF_NOT_FMT(err == cudaSuccess, "%s", str.c_str());
    }
  } else if (adjReq.space == MemorySpace::Unified) {
    auto err = cudaMallocManaged(&p, adjReq.size);

    if (err != cudaSuccess) {
      auto& map = allocs_[req.device];
      // FIXME: as of CUDA 11, a memory allocation error appears to be presented
      // via cudaGetLastError as well, and needs to be cleared. Just call the
      // function to clear it
      cudaGetLastError();

      std::stringstream ss;
      ss << "Failed to cudaMallocManaged " << adjReq.size << " bytes "
         << "(error " << (int) err << " " << cudaGetErrorString(err)
         << "\nOutstanding allocations:\n" << allocsToString(map);
      ss << "StandardGpuResources: alloc fail " << adjReq.toString()
         << " failed (cudaMallocManaged error "
         << cudaGetErrorString(err) << " [" << (int) err << "])\n";
      auto str = ss.str();

      if (allocLogging_) {
        std::cout << str;
      }

      FAISS_THROW_IF_NOT_FMT(err == cudaSuccess, "%s", str.c_str());
    }
  } else {
    FAISS_ASSERT_FMT(false, "unknown MemorySpace %d", (int) adjReq.space);
  }

  if (allocLogging_) {
    std::cout << "StandardGpuResources: alloc ok " << adjReq.toString()
              << " ptr 0x" << p << "\n";
  }

  allocs_[adjReq.device][p] = adjReq;

  return p;
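For illustration only (not part of this diff): the sticky-error behaviour that the FIXME comments above work around can be handled the same way in standalone code, by calling cudaGetLastError once after a failed cudaMalloc so later, unrelated CUDA calls are not misreported. A minimal sketch, with simplified error handling and a hypothetical helper name:

    #include <cuda_runtime.h>
    #include <cstdio>

    // Attempt an allocation; on failure, clear the error that CUDA 11 appears
    // to leave behind in cudaGetLastError before reporting and returning.
    void* tryCudaAlloc(size_t bytes) {
        void* p = nullptr;
        cudaError_t err = cudaMalloc(&p, bytes);

        if (err != cudaSuccess) {
            cudaGetLastError();  // clear the sticky error state
            std::fprintf(stderr, "cudaMalloc of %zu bytes failed: %s\n",
                         bytes, cudaGetErrorString(err));
            return nullptr;
        }
        return p;
    }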
@@ -203,15 +203,15 @@ class TestInterleavedIVFPQLayout(unittest.TestCase):
        # Try without precomputed codes
        d_g, i_g = idx_gpu.search(xq, 10)
        d_c, i_c = idx_cpu.search(xq, 10)
        self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 10)
        self.assertTrue(np.allclose(d_g, d_c))
        self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 25)
        self.assertTrue(np.allclose(d_g, d_c, rtol=5e-5, atol=3e-1))

        # Try with precomputed codes (different kernel)
        idx_gpu.setPrecomputedCodes(True)
        d_g, i_g = idx_gpu.search(xq, 10)
        d_c, i_c = idx_cpu.search(xq, 10)
        self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 10)
        self.assertTrue(np.allclose(d_g, d_c))
        self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 25)
        self.assertTrue(np.allclose(d_g, d_c, rtol=5e-5, atol=3e-1))

    def test_copy_to_cpu(self):
        res = faiss.StandardGpuResources()

@@ -247,14 +247,14 @@ class TestInterleavedIVFPQLayout(unittest.TestCase):
        # Try without precomputed codes
        d_g, i_g = idx_gpu.search(xq, 10)
        d_c, i_c = idx_cpu.search(xq, 10)
        self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 10)
        self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 20)
        self.assertTrue(np.allclose(d_g, d_c))

        # Try with precomputed codes (different kernel)
        idx_gpu.setPrecomputedCodes(True)
        d_g, i_g = idx_gpu.search(xq, 10)
        d_c, i_c = idx_cpu.search(xq, 10)
        self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 10)
        self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 20)
        self.assertTrue(np.allclose(d_g, d_c))

    def test_copy_to_gpu(self):

@@ -291,14 +291,14 @@ class TestInterleavedIVFPQLayout(unittest.TestCase):
        # Try without precomputed codes
        d_g, i_g = idx_gpu.search(xq, 10)
        d_c, i_c = idx_cpu.search(xq, 10)
        self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 10)
        self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 20)
        self.assertTrue(np.allclose(d_g, d_c))

        # Try with precomputed codes (different kernel)
        idx_gpu.setPrecomputedCodes(True)
        d_g, i_g = idx_gpu.search(xq, 10)
        d_c, i_c = idx_cpu.search(xq, 10)
        self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 10)
        self.assertGreaterEqual((i_g == i_c).sum(), i_g.size - 20)
        self.assertTrue(np.allclose(d_g, d_c))
@@ -48,6 +48,24 @@ rawGemm(cublasHandle_t handle,
  auto cAT = GetCudaType<AT>::Type;
  auto cBT = GetCudaType<BT>::Type;

  // FIXME: some weird CUDA 11 bug? where cublasSgemmEx on
  // f16 (8, 64) x f16 (64, 64)' = f32 (8, 64) returns "not supported".
  // cublasGemmEx using CUBLAS_COMPUTE_32F also fails, but
  // CUBLAS_COMPUTE_32F_PEDANTIC does not fail (as seen on a V100).
  //
  // Only use the PEDANTIC implementation if the input matrices are f16
  // and we are on CUDA 11+
#if CUDA_VERSION >= 11000
  if (cAT == CUDA_R_16F || cBT == CUDA_R_16F) {
    return cublasGemmEx(handle, transa, transb, m, n, k,
                        &fAlpha, A, cAT, lda,
                        B, cBT, ldb,
                        &fBeta,
                        C, CUDA_R_32F, ldc,
                        CUBLAS_COMPUTE_32F_PEDANTIC, CUBLAS_GEMM_DEFAULT);
  }
#endif

  // Always accumulate in f32
  return cublasSgemmEx(handle, transa, transb, m, n, k,
                       &fAlpha, A, cAT, lda,
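For reference (not code from this diff), a standalone sketch of the workaround path added above: an f16 x f16 -> f32 GEMM issued through cublasGemmEx with the PEDANTIC compute type. The dimensions and leading dimensions mirror the failing cublasSgemmEx call quoted in the summary; the helper name is illustrative.

    #include <cublas_v2.h>
    #include <cuda_fp16.h>

    // GEMM with f16 inputs and f32 accumulation, matching the shape of the
    // failing call (op(A) = A^T, m = 64, n = 8, k = 64).
    // Requires CUDA 11+, where cublasGemmEx takes a cublasComputeType_t.
    cublasStatus_t gemmF16Pedantic(cublasHandle_t handle,
                                   const __half* A, const __half* B, float* C) {
        float alpha = 1.0f;
        float beta = 0.0f;

        return cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                            /*m=*/64, /*n=*/8, /*k=*/64,
                            &alpha,
                            A, CUDA_R_16F, /*lda=*/64,
                            B, CUDA_R_16F, /*ldb=*/64,
                            &beta,
                            C, CUDA_R_32F, /*ldc=*/64,
                            CUBLAS_COMPUTE_32F_PEDANTIC, CUBLAS_GEMM_DEFAULT);
    }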
@@ -155,11 +173,16 @@ runMatrixMult(Tensor<float, 2, true>& c, bool transC,

  FAISS_ASSERT_FMT(err == CUBLAS_STATUS_SUCCESS,
                   "cublas failed (%d): "
                   "(%d, %d)%s x (%d, %d)%s = (%d, %d)%s",
                   "(%d, %d)%s x (%d, %d)%s = (%d, %d)%s "
                   "gemm params m %d n %d k %d trA %s trB %s lda %d ldb %d ldc %d",
                   (int) err,
                   a.getSize(0), a.getSize(1), transA ? "'" : "",
                   b.getSize(0), b.getSize(1), transB ? "'" : "",
                   c.getSize(0), c.getSize(1), transC ? "'" : "");
                   c.getSize(0), c.getSize(1), transC ? "'" : "",
                   m, n, k,
                   gemmTrA == CUBLAS_OP_T ? "T" : "N",
                   gemmTrB == CUBLAS_OP_T ? "T" : "N",
                   lda, ldb, ldc);
  CUDA_TEST_ERROR();
}