10 #include "DeviceUtils.h"
11 #include "../../FaissAssert.h"
13 #include <unordered_map>
15 namespace faiss {
namespace gpu {
17 int getCurrentDevice() {
19 CUDA_VERIFY(cudaGetDevice(&dev));
20 FAISS_ASSERT(dev != -1);
25 void setCurrentDevice(
int device) {
26 CUDA_VERIFY(cudaSetDevice(device));
31 CUDA_VERIFY(cudaGetDeviceCount(&numDev));
32 FAISS_ASSERT(numDev != -1);
37 void synchronizeAllDevices() {
38 for (
int i = 0; i < getNumDevices(); ++i) {
41 CUDA_VERIFY(cudaDeviceSynchronize());
45 const cudaDeviceProp& getDeviceProperties(
int device) {
46 static std::mutex mutex;
47 static std::unordered_map<int, cudaDeviceProp> properties;
49 std::lock_guard<std::mutex> guard(mutex);
51 auto it = properties.find(device);
52 if (it == properties.end()) {
54 CUDA_VERIFY(cudaGetDeviceProperties(&prop, device));
56 properties[device] = prop;
57 it = properties.find(device);
63 const cudaDeviceProp& getCurrentDeviceProperties() {
64 return getDeviceProperties(getCurrentDevice());
67 int getMaxThreads(
int device) {
68 return getDeviceProperties(device).maxThreadsPerBlock;
71 int getMaxThreadsCurrentDevice() {
72 return getMaxThreads(getCurrentDevice());
75 size_t getMaxSharedMemPerBlock(
int device) {
76 return getDeviceProperties(device).sharedMemPerBlock;
79 size_t getMaxSharedMemPerBlockCurrentDevice() {
80 return getMaxSharedMemPerBlock(getCurrentDevice());
// Queries the CUDA pointer attributes of `p` and branches on the outcome.
// NOTE(review): the value returned from each branch is not visible in this
// listing; presumably the result is the ordinal of the device owning `p`,
// with a sentinel for pointers that are not device memory -- confirm.
83 int getDeviceForAddress(
const void* p) {
88 cudaPointerAttributes att;
89 cudaError_t err = cudaPointerGetAttributes(&att, p);
// Only two outcomes are tolerated: success, or cudaErrorInvalidValue
// (presumably reported for pointers the CUDA runtime does not know about
// -- TODO confirm against the CUDA version in use).
90 FAISS_ASSERT(err == cudaSuccess ||
91 err == cudaErrorInvalidValue);
93 if (err == cudaErrorInvalidValue) {
// cudaGetLastError reads *and clears* the recorded error, so the failed
// query above does not surface again from a later, unrelated CUDA call.
95 err = cudaGetLastError();
96 FAISS_ASSERT(err == cudaErrorInvalidValue);
98 }
// The query succeeded but the attributes describe host memory.
else if (att.memoryType == cudaMemoryTypeHost) {
105 bool getFullUnifiedMemSupport(
int device) {
106 const auto& prop = getDeviceProperties(device);
107 return (prop.major >= 6);
110 bool getFullUnifiedMemSupportCurrentDevice() {
111 return getFullUnifiedMemSupport(getCurrentDevice());
114 DeviceScope::DeviceScope(
int device) {
115 prevDevice_ = getCurrentDevice();
117 if (prevDevice_ != device) {
118 setCurrentDevice(device);
124 DeviceScope::~DeviceScope() {
125 if (prevDevice_ != -1) {
126 setCurrentDevice(prevDevice_);
130 CublasHandleScope::CublasHandleScope() {
131 auto blasStatus = cublasCreate(&blasHandle_);
132 FAISS_ASSERT(blasStatus == CUBLAS_STATUS_SUCCESS);
135 CublasHandleScope::~CublasHandleScope() {
136 auto blasStatus = cublasDestroy(blasHandle_);
137 FAISS_ASSERT(blasStatus == CUBLAS_STATUS_SUCCESS);
// Body of the constructor taking a stream (its signature line is not
// visible in this listing): creates event_ with timing disabled, then
// records it in `stream`.
142 CUDA_VERIFY(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
143 CUDA_VERIFY(cudaEventRecord(event_, stream));
// Initializer list of what appears to be the move constructor: takes over
// the source object's event handle.
// NOTE(review): whether the source's handle is cleared afterwards is not
// visible here -- confirm, since the destructor destroys event_.
147 : event_(std::move(event.event_)) {
// Destructor: destroys the owned event.
// NOTE(review): any guard around the destroy call (e.g. for a moved-from
// object) is not visible in this listing -- confirm.
151 CudaEvent::~CudaEvent() {
153 CUDA_VERIFY(cudaEventDestroy(event_));
// Move assignment: overwrites event_ with the source object's handle.
// NOTE(review): the handle previously held in event_ is not destroyed by
// the visible statement; verify it is released elsewhere, otherwise it leaks.
158 CudaEvent::operator=(CudaEvent&& event) noexcept {
159 event_ = std::move(event.event_);
// Body of streamWaitOnEvent (signature not visible): makes `stream` wait
// on event_; the trailing 0 is the flags argument.
167 CUDA_VERIFY(cudaStreamWaitEvent(stream, event_, 0));
// Body of cpuWaitOnEvent (signature not visible): blocks the calling host
// thread until event_ has completed.
172 CUDA_VERIFY(cudaEventSynchronize(event_));
void cpuWaitOnEvent()
Have the CPU wait for the completion of this event.
void streamWaitOnEvent(cudaStream_t stream)
Wait on this event in this stream.
CudaEvent(cudaStream_t stream)
Creates an event and records it in this stream.