12 #include "DeviceUtils.h"
13 #include "../../FaissAssert.h"
15 #include <unordered_map>
17 namespace faiss {
namespace gpu {
// Queries the CUDA runtime for the current device (cudaGetDevice) and asserts
// that the value written into `dev` is not -1.
// NOTE(review): the declaration of `dev`, the return statement and the closing
// brace are not visible in this excerpt (the listing skips lines 20 and 23+).
// Presumably `dev` is a local int that is returned to the caller — confirm
// against the full file.
19 int getCurrentDevice() {
// CUDA_VERIFY checks the cudaError_t returned by the runtime call.
21 CUDA_VERIFY(cudaGetDevice(&dev));
22 FAISS_ASSERT(dev != -1);
27 void setCurrentDevice(
int device) {
28 CUDA_VERIFY(cudaSetDevice(device));
// Asks the runtime how many devices it reports (cudaGetDeviceCount) and
// asserts that the count written into `numDev` is not -1.
// NOTE(review): the enclosing function header and the declaration of `numDev`
// are not visible in this excerpt — presumably this is the body of
// getNumDevices() (called by synchronizeAllDevices below) and `numDev` is
// returned afterwards; confirm against the full file.
33 CUDA_VERIFY(cudaGetDeviceCount(&numDev));
34 FAISS_ASSERT(numDev != -1);
// Loops i over [0, getNumDevices()) and calls cudaDeviceSynchronize() once per
// iteration, so the host blocks until each synchronize call returns.
// NOTE(review): cudaDeviceSynchronize() acts on the *current* device only. The
// statement(s) between the loop header and the synchronize call (listing lines
// 41-42) are not visible here — presumably device i is made current there
// (e.g. through a DeviceScope); confirm against the full file, otherwise the
// same device would be synchronized repeatedly.
// NOTE(review): getNumDevices() is re-evaluated in the loop condition on every
// pass; hoisting it into a local would avoid the repeated runtime query.
39 void synchronizeAllDevices() {
40 for (
int i = 0; i < getNumDevices(); ++i) {
43 CUDA_VERIFY(cudaDeviceSynchronize());
// Looks up the cudaDeviceProp for `device` in a function-local static cache
// keyed by device id. On a cache miss the properties are fetched with
// cudaGetDeviceProperties and stored in the map.
// A function-local static mutex is held (std::lock_guard) from before the
// lookup until the guard goes out of scope, serializing access to the cache.
// NOTE(review): the declaration of `prop`, the final return statement and the
// closing braces are not visible in this excerpt — presumably the function
// returns `it->second`; confirm against the full file.
// NOTE(review): std::mutex needs <mutex>; that include is not among the
// visible #include lines (listing line 14 is missing) — confirm it is present.
47 cudaDeviceProp& getDeviceProperties(
int device) {
48 static std::mutex mutex;
49 static std::unordered_map<int, cudaDeviceProp> properties;
51 std::lock_guard<std::mutex> guard(mutex);
53 auto it = properties.find(device);
// Cache miss: query the runtime once and remember the result.
54 if (it == properties.end()) {
56 CUDA_VERIFY(cudaGetDeviceProperties(&prop, device));
58 properties[device] = prop;
// Look the entry up again so `it` refers to the element just stored.
59 it = properties.find(device);
65 int getMaxThreads(
int device) {
66 return getDeviceProperties(device).maxThreadsPerBlock;
69 int getMaxThreadsCurrentDevice() {
70 return getMaxThreads(getCurrentDevice());
73 size_t getMaxSharedMemPerBlock(
int device) {
74 return getDeviceProperties(device).sharedMemPerBlock;
77 size_t getMaxSharedMemPerBlockCurrentDevice() {
78 return getMaxSharedMemPerBlock(getCurrentDevice());
// Classifies the pointer `p` by asking the runtime for its attributes
// (cudaPointerGetAttributes) and branching on the outcome.
// The query may only yield cudaSuccess or cudaErrorInvalidValue; any other
// error code trips FAISS_ASSERT.
// NOTE(review): the first lines of the body (listing lines 82-85) and the
// value produced by each branch are not visible in this excerpt — presumably
// the invalid-value and host-memory branches report "no device" and the
// remaining branch reports att.device; confirm against the full file.
// NOTE(review): cudaPointerAttributes::memoryType is deprecated since CUDA 10
// and absent from CUDA 11+, where the field is `type`. Newer runtimes also
// report unregistered host pointers as cudaSuccess with
// cudaMemoryTypeUnregistered instead of cudaErrorInvalidValue, so this branch
// structure needs revisiting — verify against the toolkit version in use.
81 int getDeviceForAddress(
const void* p) {
86 cudaPointerAttributes att;
87 cudaError_t err = cudaPointerGetAttributes(&att, p);
88 FAISS_ASSERT(err == cudaSuccess ||
89 err == cudaErrorInvalidValue);
91 if (err == cudaErrorInvalidValue) {
// cudaGetLastError() reads and clears the pending error, so the failed
// attribute query does not leak into later CUDA calls; it must be the same
// error that was just returned.
93 err = cudaGetLastError();
94 FAISS_ASSERT(err == cudaErrorInvalidValue);
96 }
// Query succeeded: check whether the runtime classifies `p` as host memory.
else if (att.memoryType == cudaMemoryTypeHost) {
// Stores the device that is current on entry in prevDevice_ and, when that
// differs from the requested `device`, switches to `device`.
// NOTE(review): the remainder of the constructor (listing lines 108+) is not
// visible in this excerpt. The destructor skips restoration when prevDevice_
// is -1, so presumably the "already on `device`" case stores -1 here —
// confirm against the full file.
103 DeviceScope::DeviceScope(
int device) {
104 prevDevice_ = getCurrentDevice();
// Only issue a device switch when one is actually needed.
106 if (prevDevice_ != device) {
107 setCurrentDevice(device);
113 DeviceScope::~DeviceScope() {
114 if (prevDevice_ != -1) {
115 setCurrentDevice(prevDevice_);
119 CublasHandleScope::CublasHandleScope() {
120 auto blasStatus = cublasCreate(&blasHandle_);
121 FAISS_ASSERT(blasStatus == CUBLAS_STATUS_SUCCESS);
124 CublasHandleScope::~CublasHandleScope() {
125 auto blasStatus = cublasDestroy(blasHandle_);
126 FAISS_ASSERT(blasStatus == CUBLAS_STATUS_SUCCESS);
// Creates event_ with the cudaEventDisableTiming flag and records it on
// `stream`; both runtime calls are checked by CUDA_VERIFY.
// NOTE(review): the enclosing function header is not visible in this excerpt —
// presumably this is the body of CudaEvent::CudaEvent(cudaStream_t stream);
// confirm against the full file.
131 CUDA_VERIFY(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
132 CUDA_VERIFY(cudaEventRecord(event_, stream));
// Member initializer that copies the other object's event_ handle into this
// object (std::move on a plain handle is a copy of the value).
// NOTE(review): the constructor header and body are not visible in this
// excerpt — presumably this is the CudaEvent move constructor, and the body
// clears event.event_ so the moved-from object does not destroy the handle;
// confirm against the full file.
136 : event_(std::move(event.event_)) {
// Releases the CUDA event held in event_ via cudaEventDestroy (checked by
// CUDA_VERIFY).
// NOTE(review): listing line 141, between the header and the destroy call, is
// not visible in this excerpt — presumably a guard that skips destruction for
// an empty / moved-from event_; confirm against the full file.
140 CudaEvent::~CudaEvent() {
142 CUDA_VERIFY(cudaEventDestroy(event_));
// Move assignment: overwrites this object's event_ with the handle taken from
// `event`.
// NOTE(review): the return type line and the rest of the body are not visible
// in this excerpt — presumably event.event_ is cleared and *this is returned;
// confirm against the full file.
// NOTE(review): as far as the visible code shows, the event previously held in
// event_ is overwritten without cudaEventDestroy, which would leak it when the
// target already owns an event; self-assignment is also not guarded. Verify
// and fix in the full definition.
147 CudaEvent::operator=(CudaEvent&& event) noexcept {
148 event_ = std::move(event.event_);
// Makes work subsequently submitted to `stream` wait for event_
// (cudaStreamWaitEvent with flags = 0); the call is checked by CUDA_VERIFY.
// NOTE(review): the enclosing function header is not visible in this excerpt —
// presumably CudaEvent::streamWaitOnEvent(cudaStream_t stream); confirm.
156 CUDA_VERIFY(cudaStreamWaitEvent(stream, event_, 0));
// Blocks the calling host thread until event_ has completed
// (cudaEventSynchronize); the call is checked by CUDA_VERIFY.
// NOTE(review): the enclosing function header is not visible in this excerpt —
// presumably CudaEvent::cpuWaitOnEvent(); confirm.
161 CUDA_VERIFY(cudaEventSynchronize(event_));
void cpuWaitOnEvent()
Have the CPU wait for the completion of this event.
void streamWaitOnEvent(cudaStream_t stream)
Wait on this event in this stream.
CudaEvent(cudaStream_t stream)
Creates an event and records it in this stream.