11 #include "DeviceUtils.h"
12 #include "../../FaissAssert.h"
14 #include <unordered_map>
16 namespace faiss {
namespace gpu {
// Queries the CUDA runtime for the current device via cudaGetDevice and
// asserts that the reported index is not -1.
// NOTE(review): the declaration of `dev` and the return statement are not
// visible in this excerpt; presumably `dev` is initialised to -1 and
// returned -- confirm against the full file.
18 int getCurrentDevice() {
// CUDA_VERIFY wraps the runtime call so its error code is checked.
20 CUDA_VERIFY(cudaGetDevice(&dev));
// -1 is treated as "the runtime did not report a device".
21 FAISS_ASSERT(dev != -1);
// Makes `device` the current CUDA device by calling cudaSetDevice; the
// returned status is checked through CUDA_VERIFY.
26 void setCurrentDevice(
int device) {
27 CUDA_VERIFY(cudaSetDevice(device));
// NOTE(review): the enclosing function header and the declaration of
// `numDev` are not visible in this excerpt; these lines look like the body
// of getNumDevices() -- confirm against the full file.
// Asks the runtime how many devices exist and asserts the count is not -1.
32 CUDA_VERIFY(cudaGetDeviceCount(&numDev));
33 FAISS_ASSERT(numDev != -1);
// Loops over every device index in [0, getNumDevices()) and issues a
// cudaDeviceSynchronize() on each iteration.
// NOTE(review): the statement that makes device `i` current for the
// iteration is not visible in this excerpt -- confirm it exists, otherwise
// the same device would be synchronised every time.
38 void synchronizeAllDevices() {
// getNumDevices() is re-evaluated as part of the loop condition.
39 for (
int i = 0; i < getNumDevices(); ++i) {
42 CUDA_VERIFY(cudaDeviceSynchronize());
// Looks up the cudaDeviceProp record for `device`, using a function-local
// static map as a cache so cudaGetDeviceProperties is only called on the
// first request for a given device index.
// NOTE(review): the declaration of `prop` and the return statement are not
// visible in this excerpt; presumably the reference returned is
// `it->second` -- confirm against the full file.
46 const cudaDeviceProp& getDeviceProperties(
int device) {
// Function-local statics: one mutex and one cache shared by all callers.
47 static std::mutex mutex;
48 static std::unordered_map<int, cudaDeviceProp> properties;
// The lock is held for the remainder of the enclosing scope, covering both
// the lookup and the insertion below.
50 std::lock_guard<std::mutex> guard(mutex);
52 auto it = properties.find(device);
// Cache miss: query the runtime, store the result, then look it up again
// so `it` refers to the stored entry.
53 if (it == properties.end()) {
55 CUDA_VERIFY(cudaGetDeviceProperties(&prop, device));
57 properties[device] = prop;
58 it = properties.find(device);
// Convenience overload: returns getDeviceProperties() for whichever device
// getCurrentDevice() reports.
64 const cudaDeviceProp& getCurrentDeviceProperties() {
65 return getDeviceProperties(getCurrentDevice());
// Returns the maxThreadsPerBlock field of the properties record obtained
// from getDeviceProperties(device).
68 int getMaxThreads(
int device) {
69 return getDeviceProperties(device).maxThreadsPerBlock;
// Convenience overload: getMaxThreads() for the device reported by
// getCurrentDevice().
72 int getMaxThreadsCurrentDevice() {
73 return getMaxThreads(getCurrentDevice());
// Returns the sharedMemPerBlock field of the properties record obtained
// from getDeviceProperties(device).
76 size_t getMaxSharedMemPerBlock(
int device) {
77 return getDeviceProperties(device).sharedMemPerBlock;
// Convenience overload: getMaxSharedMemPerBlock() for the device reported
// by getCurrentDevice().
80 size_t getMaxSharedMemPerBlockCurrentDevice() {
81 return getMaxSharedMemPerBlock(getCurrentDevice());
// Inspects pointer `p` with cudaPointerGetAttributes and branches on the
// outcome: the call reporting cudaErrorInvalidValue, or the attributes
// describing host memory.
// NOTE(review): the bodies of both branches, any remaining branch, and all
// return statements are not visible in this excerpt, so the values returned
// in each case cannot be confirmed from here.
84 int getDeviceForAddress(
const void* p) {
89 cudaPointerAttributes att;
// The status is captured rather than wrapped in CUDA_VERIFY because one
// failure code is tolerated below.
90 cudaError_t err = cudaPointerGetAttributes(&att, p);
// Only success or cudaErrorInvalidValue are acceptable; anything else trips
// the assertion.
91 FAISS_ASSERT(err == cudaSuccess ||
92 err == cudaErrorInvalidValue);
94 if (err == cudaErrorInvalidValue) {
// cudaGetLastError() reads and clears the runtime's recorded error, so the
// tolerated failure does not surface on a later, unrelated call.
96 err = cudaGetLastError();
97 FAISS_ASSERT(err == cudaErrorInvalidValue);
99 }
// NOTE(review): `memoryType` may be deprecated or renamed in newer CUDA
// toolkits -- verify against the toolkit version this file targets.
else if (att.memoryType == cudaMemoryTypeHost) {
// Reports whether `device` is considered to have full unified-memory
// support. The only criterion applied is that the `major` field of its
// device properties is at least 6.
106 bool getFullUnifiedMemSupport(
int device) {
107 const auto& prop = getDeviceProperties(device);
108 return (prop.major >= 6);
// Convenience overload: getFullUnifiedMemSupport() for the device reported
// by getCurrentDevice().
111 bool getFullUnifiedMemSupportCurrentDevice() {
112 return getFullUnifiedMemSupport(getCurrentDevice());
// Records the device that is current on entry in prevDevice_ and, if it
// differs from `device`, switches to `device`.
// NOTE(review): what happens when the two are equal is not visible in this
// excerpt; the destructor tests prevDevice_ against -1, so presumably an
// else branch stores -1 to mean "nothing to restore" -- confirm.
115 DeviceScope::DeviceScope(
int device) {
116 prevDevice_ = getCurrentDevice();
// Skip the cudaSetDevice call when the requested device is already current.
118 if (prevDevice_ != device) {
119 setCurrentDevice(device);
// Restores the device stored in prevDevice_, unless that member holds -1,
// in which case no device switch is performed.
125 DeviceScope::~DeviceScope() {
126 if (prevDevice_ != -1) {
127 setCurrentDevice(prevDevice_);
// Creates a cuBLAS handle into blasHandle_ and asserts that cublasCreate
// reported CUBLAS_STATUS_SUCCESS.
131 CublasHandleScope::CublasHandleScope() {
132 auto blasStatus = cublasCreate(&blasHandle_);
133 FAISS_ASSERT(blasStatus == CUBLAS_STATUS_SUCCESS);
// Destroys the cuBLAS handle held in blasHandle_ and asserts that
// cublasDestroy reported CUBLAS_STATUS_SUCCESS.
136 CublasHandleScope::~CublasHandleScope() {
137 auto blasStatus = cublasDestroy(blasHandle_);
138 FAISS_ASSERT(blasStatus == CUBLAS_STATUS_SUCCESS);
// NOTE(review): the enclosing function header is not visible in this
// excerpt; these lines look like the body of the CudaEvent constructor that
// takes a `stream` -- confirm against the full file.
// Creates event_ with the cudaEventDisableTiming flag, then records it on
// `stream`.
143 CUDA_VERIFY(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
144 CUDA_VERIFY(cudaEventRecord(event_, stream));
// NOTE(review): the constructor header and body are not visible in this
// excerpt; this initializer list presumably belongs to the CudaEvent move
// constructor. It initialises event_ from the source object's event_.
// Whether the source's handle is reset afterwards cannot be seen here --
// confirm.
148 : event_(std::move(event.event_)) {
// Destroys the CUDA event held in event_.
// NOTE(review): a line between the header and the destroy call is missing
// from this excerpt; presumably it guards against an empty or moved-from
// event_ -- confirm.
152 CudaEvent::~CudaEvent() {
154 CUDA_VERIFY(cudaEventDestroy(event_));
// Move assignment: takes over the event handle from `event`.
// NOTE(review): the return type line, the return statement and any reset of
// `event.event_` are not visible in this excerpt. In the visible lines the
// handle previously stored in event_ is overwritten without being
// destroyed -- confirm whether that is handled elsewhere.
159 CudaEvent::operator=(CudaEvent&& event) noexcept {
160 event_ = std::move(event.event_);
// NOTE(review): the enclosing function header is not visible in this
// excerpt; this looks like the body of streamWaitOnEvent(stream) -- confirm.
// Calls cudaStreamWaitEvent on `stream` for event_, passing 0 as the flags
// argument.
168 CUDA_VERIFY(cudaStreamWaitEvent(stream, event_, 0));
// NOTE(review): the enclosing function header is not visible in this
// excerpt; this looks like the body of cpuWaitOnEvent() -- confirm.
// Calls cudaEventSynchronize on event_, with the status checked by
// CUDA_VERIFY.
173 CUDA_VERIFY(cudaEventSynchronize(event_));
void cpuWaitOnEvent()
Have the CPU wait for the completion of this event.
void streamWaitOnEvent(cudaStream_t stream)
Makes the given stream wait on this event.
CudaEvent(cudaStream_t stream)
Creates an event and records it in this stream.