#include "DeviceUtils.h"
#include "../../FaissAssert.h"

#include <mutex>
#include <unordered_map>
16 namespace faiss {
namespace gpu {
/// Returns the currently selected CUDA device, as reported by cudaGetDevice.
/// Runtime errors are reported through CUDA_VERIFY; FAISS_ASSERT checks that
/// a device index was actually written.
int getCurrentDevice() {
  // -1 is a sentinel: if cudaGetDevice never writes `dev`, the assert fires.
  int dev = -1;
  CUDA_VERIFY(cudaGetDevice(&dev));
  FAISS_ASSERT(dev != -1);

  return dev;
}
/// Selects `device` as the current CUDA device via cudaSetDevice.
/// A non-success status is reported through CUDA_VERIFY.
void setCurrentDevice(int device) {
  CUDA_VERIFY(cudaSetDevice(device));
}
/// Returns the number of CUDA devices, as reported by cudaGetDeviceCount.
/// Runtime errors are reported through CUDA_VERIFY; FAISS_ASSERT checks that
/// a count was actually written.
int getNumDevices() {
  // -1 is a sentinel: if cudaGetDeviceCount never writes `numDev`, the
  // assert fires.
  int numDev = -1;
  CUDA_VERIFY(cudaGetDeviceCount(&numDev));
  FAISS_ASSERT(numDev != -1);

  return numDev;
}
/// Calls cudaDeviceSynchronize on every device reported by getNumDevices().
/// Each device is made current in turn through a DeviceScope, whose
/// destructor switches back to the previously current device.
void synchronizeAllDevices() {
  // The device count does not change inside this loop; query it once.
  const int numDevices = getNumDevices();

  for (int i = 0; i < numDevices; ++i) {
    // cudaDeviceSynchronize acts on the *current* device only, so device i
    // must be made current before calling it; otherwise the same device
    // would be synchronized numDevices times.
    DeviceScope scope(i);

    CUDA_VERIFY(cudaDeviceSynchronize());
  }
}
/// Returns the cudaDeviceProp for `device`.
///
/// The properties are queried with cudaGetDeviceProperties the first time a
/// device is requested and then served from a function-local static cache.
/// Lookups and insertions are serialized by a function-local mutex.
/// References to std::unordered_map elements are not invalidated by later
/// insertions (rehashing only invalidates iterators), and nothing is ever
/// erased, so the returned reference stays valid after the lock is released.
///
/// NOTE(review): the reference is non-const, so a caller writing through it
/// would modify the shared cache without holding the mutex.
cudaDeviceProp& getDeviceProperties(int device) {
  static std::mutex mutex;
  static std::unordered_map<int, cudaDeviceProp> properties;

  std::lock_guard<std::mutex> guard(mutex);

  auto it = properties.find(device);
  if (it == properties.end()) {
    cudaDeviceProp prop;
    CUDA_VERIFY(cudaGetDeviceProperties(&prop, device));

    // Keep the iterator returned by the insertion instead of searching the
    // map a second time.
    it = properties.emplace(device, prop).first;
  }

  return it->second;
}
/// Maximum number of threads per block on `device`
/// (cudaDeviceProp::maxThreadsPerBlock).
int getMaxThreads(int device) {
  const cudaDeviceProp& props = getDeviceProperties(device);
  return props.maxThreadsPerBlock;
}
/// getMaxThreads() for whichever device getCurrentDevice() reports.
int getMaxThreadsCurrentDevice() {
  const int device = getCurrentDevice();
  return getMaxThreads(device);
}
/// Shared memory available to a single block on `device`
/// (cudaDeviceProp::sharedMemPerBlock).
size_t getMaxSharedMemPerBlock(int device) {
  const cudaDeviceProp& props = getDeviceProperties(device);
  return props.sharedMemPerBlock;
}
/// getMaxSharedMemPerBlock() for whichever device getCurrentDevice() reports.
size_t getMaxSharedMemPerBlockCurrentDevice() {
  const int device = getCurrentDevice();
  return getMaxSharedMemPerBlock(device);
}
/// Returns the CUDA device index that owns the memory at `p`.
///
/// Returns -1 when:
/// - `p` is null;
/// - the runtime does not know the pointer;
/// - the pointer refers to host memory.
int getDeviceForAddress(const void* p) {
  if (!p) {
    return -1;
  }

  cudaPointerAttributes att;
  cudaError_t err = cudaPointerGetAttributes(&att, p);
  // Older runtimes (before CUDA 11) report pointers they do not know about,
  // such as plain malloc'd host memory, as cudaErrorInvalidValue.
  FAISS_ASSERT(err == cudaSuccess ||
               err == cudaErrorInvalidValue);

  if (err == cudaErrorInvalidValue) {
    // The failure is also recorded as this thread's last runtime error;
    // consume it so a later, unrelated error check does not see it.
    err = cudaGetLastError();
    FAISS_ASSERT(err == cudaErrorInvalidValue);
    return -1;
  }

#if CUDART_VERSION >= 10000
  // `memoryType` is deprecated since CUDA 10.0 and absent from newer
  // toolkits. `type` also distinguishes managed memory and (in newer
  // runtimes) unregistered host memory, which is reported with cudaSuccess.
  if (att.type != cudaMemoryTypeDevice &&
      att.type != cudaMemoryTypeManaged) {
    return -1;
  }
#else
  if (att.memoryType == cudaMemoryTypeHost) {
    return -1;
  }
#endif

  return att.device;
}
/// True when `device` has compute capability major version 6 or higher.
/// This file uses that threshold as its notion of "full" unified-memory
/// support.
bool getFullUnifiedMemSupport(int device) {
  const int computeMajor = getDeviceProperties(device).major;
  return computeMajor >= 6;
}
/// getFullUnifiedMemSupport() for whichever device getCurrentDevice() reports.
bool getFullUnifiedMemSupportCurrentDevice() {
  const int device = getCurrentDevice();
  return getFullUnifiedMemSupport(device);
}
/// Makes `device` the current CUDA device for the lifetime of this object.
///
/// prevDevice_ holds the device to restore on destruction. It is set to -1
/// when `device` was already current, meaning there is nothing to restore.
DeviceScope::DeviceScope(int device) {
  prevDevice_ = getCurrentDevice();

  if (prevDevice_ != device) {
    setCurrentDevice(device);
  } else {
    // Already on the requested device: record "nothing to restore" so the
    // destructor skips a redundant cudaSetDevice call.
    prevDevice_ = -1;
  }
}
/// Switches back to the device stored in prevDevice_.
/// Does nothing when prevDevice_ is -1.
DeviceScope::~DeviceScope() {
  if (prevDevice_ == -1) {
    return;
  }

  setCurrentDevice(prevDevice_);
}
/// Creates the cuBLAS handle stored in blasHandle_.
/// A non-success status fails FAISS_ASSERT.
CublasHandleScope::CublasHandleScope() {
  const cublasStatus_t blasStatus = cublasCreate(&blasHandle_);
  FAISS_ASSERT(blasStatus == CUBLAS_STATUS_SUCCESS);
}
/// Destroys the cuBLAS handle stored in blasHandle_.
/// A non-success status fails FAISS_ASSERT.
CublasHandleScope::~CublasHandleScope() {
  const cublasStatus_t blasStatus = cublasDestroy(blasHandle_);
  FAISS_ASSERT(blasStatus == CUBLAS_STATUS_SUCCESS);
}
/// Creates an event with timing disabled (cudaEventDisableTiming) and
/// records it in `stream`.
CudaEvent::CudaEvent(cudaStream_t stream)
    : event_(nullptr) {
  CUDA_VERIFY(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
  CUDA_VERIFY(cudaEventRecord(event_, stream));
}
/// Move constructor: takes over `event`'s handle and leaves `event` with a
/// null handle, so the same CUDA event is never owned (and destroyed) twice.
CudaEvent::CudaEvent(CudaEvent&& event) noexcept
    : event_(event.event_) {
  event.event_ = nullptr;
}
/// Destroys the owned event, if any.
/// A null handle means nothing is owned (e.g. a moved-from object). It is
/// skipped rather than passed to cudaEventDestroy, which would report an
/// error.
CudaEvent::~CudaEvent() {
  if (event_) {
    CUDA_VERIFY(cudaEventDestroy(event_));
  }
}
/// Move assignment: releases the event this object currently owns, then
/// takes over `event`'s handle and leaves `event` with a null handle.
/// Self-assignment leaves the object unchanged.
CudaEvent&
CudaEvent::operator=(CudaEvent&& event) noexcept {
  if (this != &event) {
    // Without this, the previously owned event would leak.
    if (event_) {
      CUDA_VERIFY(cudaEventDestroy(event_));
    }

    event_ = event.event_;
    // The source must not keep the handle, or it would be destroyed twice.
    event.event_ = nullptr;
  }

  return *this;
}
/// Makes work submitted to `stream` after this call wait until this event
/// has completed (cudaStreamWaitEvent with flags 0).
void CudaEvent::streamWaitOnEvent(cudaStream_t stream) {
  CUDA_VERIFY(cudaStreamWaitEvent(stream, event_, 0));
}
/// Blocks the calling host thread until this event has completed
/// (cudaEventSynchronize).
void CudaEvent::cpuWaitOnEvent() {
  CUDA_VERIFY(cudaEventSynchronize(event_));
}
void cpuWaitOnEvent()
Have the CPU wait for the completion of this event.
void streamWaitOnEvent(cudaStream_t stream)
Make work subsequently enqueued on the given stream wait until this event has completed.
CudaEvent(cudaStream_t stream)
Creates an event and records it in this stream.