# encoding: utf-8
# copied from: https://github.com/open-mmlab/OpenUnReID/blob/66bb2ae0b00575b80fbe8915f4d4f4739cc21206/openunreid/core/utils/faiss_utils.py

import faiss
import torch


def swig_ptr_from_FloatTensor(x):
    """Cast a contiguous float32 tensor's storage to a SWIG float* for faiss."""
    assert x.is_contiguous()
    assert x.dtype == torch.float32
    return faiss.cast_integer_to_float_ptr(
        # data_ptr() is a byte address, so scale the offset by 4 bytes per float32
        x.storage().data_ptr() + x.storage_offset() * 4
    )


def swig_ptr_from_LongTensor(x):
    """Cast a contiguous int64 tensor's storage to a SWIG long* for faiss."""
    assert x.is_contiguous()
    assert x.dtype == torch.int64, "dtype=%s" % x.dtype
    return faiss.cast_integer_to_long_ptr(
        # 8 bytes per int64 element
        x.storage().data_ptr() + x.storage_offset() * 8
    )


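# Illustrative sketch, not part of the upstream file: how the two casts above
# are meant to be used. The casts do not copy, so the tensors must stay alive
# and contiguous for as long as faiss holds the resulting pointers.
def _example_swig_ptrs():
    feats = torch.randn(8, 128)                    # contiguous float32
    labels = torch.empty(8, 5, dtype=torch.int64)  # contiguous int64
    fptr = swig_ptr_from_FloatTensor(feats)        # SWIG float* into feats
    lptr = swig_ptr_from_LongTensor(labels)        # SWIG long* into labels
    return (feats, fptr), (labels, lptr)           # keep tensors paired with pointers

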
def search_index_pytorch(index, x, k, D=None, I=None):
    """Call the search function of an index with pytorch tensor I/O (CPU
    and GPU supported)."""
    assert x.is_contiguous()
    n, d = x.size()
    assert d == index.d

    if D is None:
        D = torch.empty((n, k), dtype=torch.float32, device=x.device)
    else:
        assert D.size() == (n, k)

    if I is None:
        I = torch.empty((n, k), dtype=torch.int64, device=x.device)
    else:
        assert I.size() == (n, k)

    # sync so pending kernels writing x are done before faiss reads the raw
    # pointers, and again afterwards so the results are visible on return
    torch.cuda.synchronize()
    xptr = swig_ptr_from_FloatTensor(x)
    Iptr = swig_ptr_from_LongTensor(I)
    Dptr = swig_ptr_from_FloatTensor(D)
    index.search_c(n, xptr, k, Dptr, Iptr)
    torch.cuda.synchronize()
    return D, I


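# Illustrative sketch, not part of the upstream file: querying a plain faiss
# index with torch tensor I/O. The 128-d features are hypothetical, and a
# CUDA-enabled torch build is assumed since search_index_pytorch calls
# torch.cuda.synchronize() unconditionally.
def _example_search_index_pytorch():
    import numpy as np

    base = np.random.rand(1000, 128).astype("float32")
    index = faiss.IndexFlatL2(128)  # exact L2 index
    index.add(base)
    query = torch.from_numpy(base[:16])  # any contiguous float32 tensor works
    D, I = search_index_pytorch(index, query, k=10)
    return D, I

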
def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None, metric=faiss.METRIC_L2):
    """Brute-force kNN of the queries xq against the database xb on the GPU,
    without building an index first."""
    assert xb.device == xq.device

    nq, d = xq.size()
    if xq.is_contiguous():
        xq_row_major = True
    elif xq.t().is_contiguous():
        xq = xq.t()  # I initially wrote xq:t(), Lua is still haunting me :-)
        xq_row_major = False
    else:
        raise TypeError("matrix should be row or column-major")

    xq_ptr = swig_ptr_from_FloatTensor(xq)

    nb, d2 = xb.size()
    assert d2 == d
    if xb.is_contiguous():
        xb_row_major = True
    elif xb.t().is_contiguous():
        xb = xb.t()
        xb_row_major = False
    else:
        raise TypeError("matrix should be row or column-major")
    xb_ptr = swig_ptr_from_FloatTensor(xb)

    if D is None:
        D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
    else:
        assert D.shape == (nq, k)
        assert D.device == xb.device

    if I is None:
        I = torch.empty(nq, k, device=xb.device, dtype=torch.int64)
    else:
        assert I.shape == (nq, k)
        assert I.device == xb.device

    D_ptr = swig_ptr_from_FloatTensor(D)
    I_ptr = swig_ptr_from_LongTensor(I)

    faiss.bruteForceKnn(
        res,
        metric,
        xb_ptr,
        xb_row_major,
        nb,
        xq_ptr,
        xq_row_major,
        nq,
        d,
        k,
        D_ptr,
        I_ptr,
    )

    return D, I


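# Illustrative sketch, not part of the upstream file: brute-force GPU kNN over
# raw tensors, with no index construction at all. Requires a GPU build of
# faiss; the sizes are hypothetical.
def _example_search_raw_array_pytorch():
    res = faiss.StandardGpuResources()
    xb = torch.randn(10000, 256).cuda()  # database vectors
    xq = torch.randn(32, 256).cuda()     # query vectors
    D, I = search_raw_array_pytorch(res, xb, xq, k=5)  # METRIC_L2 by default
    return D, I

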
def index_init_gpu(ngpus, feat_dim):
    """Build an empty flat (exact) L2 index sharded across ngpus GPUs."""
    flat_config = []
    for i in range(ngpus):
        cfg = faiss.GpuIndexFlatConfig()
        cfg.useFloat16 = False
        cfg.device = i
        flat_config.append(cfg)

    res = [faiss.StandardGpuResources() for i in range(ngpus)]
    indexes = [
        faiss.GpuIndexFlatL2(res[i], feat_dim, flat_config[i]) for i in range(ngpus)
    ]
    index = faiss.IndexShards(feat_dim)
    for sub_index in indexes:
        index.add_shard(sub_index)
    index.reset()  # make sure the shards start out empty
    return index


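# Illustrative sketch, not part of the upstream file: the sharded index is a
# drop-in faiss index whose add()/search() fan out across the GPUs. Assumes a
# GPU build of faiss with at least one visible device; feat_dim=2048 is
# hypothetical.
def _example_index_init_gpu():
    import numpy as np

    index = index_init_gpu(faiss.get_num_gpus(), 2048)
    feats = np.random.rand(5000, 2048).astype("float32")
    index.add(feats)
    D, I = index.search(feats[:10], 30)
    return D, I

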
def index_init_cpu(feat_dim):
    """Build an empty flat (exact) L2 index on the CPU."""
    return faiss.IndexFlatL2(feat_dim)
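

# Illustrative sketch, not part of the upstream file: choosing between the two
# initializers. faiss.get_num_gpus() assumes the GPU build of faiss; CPU-only
# installs should just call index_init_cpu directly.
def _example_index_init(feat_dim=512):
    ngpus = faiss.get_num_gpus()
    return index_init_gpu(ngpus, feat_dim) if ngpus > 0 else index_init_cpu(feat_dim)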