# encoding: utf-8
# copy from: https://github.com/open-mmlab/OpenUnReID/blob/66bb2ae0b00575b80fbe8915f4d4f4739cc21206/openunreid/core/utils/faiss_utils.py

import faiss
import torch


def swig_ptr_from_FloatTensor(x):
    """Return a faiss SWIG ``float*`` pointing at the first element of ``x``.

    ``x`` must be a contiguous ``torch.float32`` tensor. The device is not
    checked; the raw address is handed to faiss as-is. The returned pointer
    does not keep ``x`` alive, so the caller must hold a reference to ``x``
    for as long as the pointer is used.
    """
    assert x.is_contiguous(), "tensor must be contiguous"
    assert x.dtype == torch.float32, "dtype=%s" % x.dtype
    # ``Tensor.data_ptr()`` already accounts for the storage offset, so this
    # equals ``storage().data_ptr() + storage_offset() * 4`` without using
    # the deprecated ``TypedStorage`` API.
    return faiss.cast_integer_to_float_ptr(x.data_ptr())


def swig_ptr_from_LongTensor(x):
    """Return a faiss SWIG ``long*`` pointing at the first element of ``x``.

    ``x`` must be a contiguous ``torch.int64`` tensor. The device is not
    checked. The caller must keep ``x`` alive while the pointer is in use.
    """
    assert x.is_contiguous(), "tensor must be contiguous"
    assert x.dtype == torch.int64, "dtype=%s" % x.dtype
    # NOTE(review): newer faiss releases also expose
    # ``cast_integer_to_idx_t_ptr``; ``long`` is not 64-bit on every platform.
    return faiss.cast_integer_to_long_ptr(x.data_ptr())


def search_index_pytorch(index, x, k, D=None, I=None):
    """Call ``index.search`` with PyTorch tensor I/O (CPU and GPU supported).

    Args:
        index: faiss index; ``index.d`` must equal the number of columns of
            ``x``.
        x: contiguous float32 query matrix of shape ``(n, d)``.
        k: number of neighbours requested per query.
        D: optional preallocated float32 tensor of shape ``(n, k)`` for the
            distances; allocated on ``x.device`` when omitted.
        I: optional preallocated int64 tensor of shape ``(n, k)`` for the
            neighbour ids; allocated on ``x.device`` when omitted.

    Returns:
        ``(D, I)``, the tensors that faiss filled in.
    """
    assert x.is_contiguous()
    n, d = x.size()
    assert d == index.d, "query dim %d != index dim %d" % (d, index.d)

    if D is None:
        D = torch.empty((n, k), dtype=torch.float32, device=x.device)
    else:
        assert D.size() == (n, k)

    if I is None:
        I = torch.empty((n, k), dtype=torch.int64, device=x.device)
    else:
        assert I.size() == (n, k)

    # faiss accesses the raw buffers outside torch's asynchronous CUDA
    # execution, so pending torch work on the inputs must finish first and
    # faiss' writes must finish before torch reads D/I. Pure-CPU tensors
    # need no synchronisation, and torch.cuda.synchronize() raises on builds
    # or machines without CUDA, which would break the CPU path.
    uses_cuda = x.is_cuda or D.is_cuda or I.is_cuda
    if uses_cuda:
        torch.cuda.synchronize()
    xptr = swig_ptr_from_FloatTensor(x)
    Iptr = swig_ptr_from_LongTensor(I)
    Dptr = swig_ptr_from_FloatTensor(D)
    index.search_c(n, xptr, k, Dptr, Iptr)
    if uses_cuda:
        torch.cuda.synchronize()
    return D, I


def _row_or_col_major(t):
    """Return ``(t_view, is_row_major)`` for a row- or column-major 2-D tensor.

    Column-major input is returned transposed, so the view is contiguous in
    memory. Raises ``TypeError`` for any other layout.
    """
    if t.is_contiguous():
        return t, True
    if t.t().is_contiguous():
        return t.t(), False
    raise TypeError("matrix should be row or column-major")


def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None, metric=faiss.METRIC_L2):
    """Brute-force k-NN of queries ``xq`` against database ``xb`` via faiss.

    Args:
        res: faiss GPU resources object passed straight to
            ``faiss.bruteForceKnn``.
        xb: float32 database matrix of shape ``(nb, d)``, row- or
            column-major.
        xq: float32 query matrix of shape ``(nq, d)``, row- or column-major,
            on the same device as ``xb``.
        k: number of neighbours requested per query.
        D: optional float32 output of shape ``(nq, k)`` on ``xb.device``.
        I: optional int64 output of shape ``(nq, k)`` on ``xb.device``.
        metric: faiss metric constant (default ``faiss.METRIC_L2``).

    Returns:
        ``(D, I)`` distance and id tensors.

    Raises:
        TypeError: if ``xq`` or ``xb`` is neither row- nor column-major.
    """
    assert xb.device == xq.device

    nq, d = xq.size()
    xq, xq_row_major = _row_or_col_major(xq)
    xq_ptr = swig_ptr_from_FloatTensor(xq)

    nb, d2 = xb.size()
    assert d2 == d
    xb, xb_row_major = _row_or_col_major(xb)
    xb_ptr = swig_ptr_from_FloatTensor(xb)

    if D is None:
        D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
    else:
        assert D.shape == (nq, k)
        assert D.device == xb.device

    if I is None:
        I = torch.empty(nq, k, device=xb.device, dtype=torch.int64)
    else:
        assert I.shape == (nq, k)
        assert I.device == xb.device

    D_ptr = swig_ptr_from_FloatTensor(D)
    I_ptr = swig_ptr_from_LongTensor(I)

    faiss.bruteForceKnn(
        res,
        metric,
        xb_ptr,
        xb_row_major,
        nb,
        xq_ptr,
        xq_row_major,
        nq,
        d,
        k,
        D_ptr,
        I_ptr,
    )

    return D, I


def _gpu_flat_config(device):
    """Build a ``GpuIndexFlatConfig`` for ``device`` with float16 storage off."""
    cfg = faiss.GpuIndexFlatConfig()
    cfg.useFloat16 = False
    cfg.device = device
    return cfg


def index_init_gpu(ngpus, feat_dim):
    """Create an empty L2 flat index sharded over GPUs ``0 .. ngpus-1``.

    Each GPU gets its own ``StandardGpuResources`` and ``GpuIndexFlatL2``,
    and these are combined in a ``faiss.IndexShards`` of dimension
    ``feat_dim``.
    """
    flat_config = [_gpu_flat_config(device) for device in range(ngpus)]
    res = [faiss.StandardGpuResources() for _ in range(ngpus)]
    indexes = [
        faiss.GpuIndexFlatL2(gpu_res, feat_dim, cfg)
        for gpu_res, cfg in zip(res, flat_config)
    ]
    # NOTE(review): ``res`` and ``indexes`` go out of scope on return; this
    # relies on faiss' Python wrappers keeping references to them from the
    # returned shards index. Verify this for the installed faiss version.
    index = faiss.IndexShards(feat_dim)
    for sub_index in indexes:
        index.add_shard(sub_index)
    index.reset()
    return index


def index_init_cpu(feat_dim):
    """Create an empty CPU ``faiss.IndexFlatL2`` of dimension ``feat_dim``."""
    return faiss.IndexFlatL2(feat_dim)