# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import numpy as np
import faiss
from typing import List
import random
import logging
from functools import lru_cache


def create_dataset_from_oivf_config(cfg, ds_name):
    """Builds a MultiFileVectorDataset for the dataset `ds_name` described
    in the offline-IVF config dict `cfg`."""
    normalise = cfg["normalise"] if "normalise" in cfg else False
    return MultiFileVectorDataset(
        cfg["datasets"][ds_name]["root"],
        [
            FileDescriptor(
                f["name"], f["format"], np.dtype(f["dtype"]), f["size"]
            )
            for f in cfg["datasets"][ds_name]["files"]
        ],
        cfg["d"],
        normalise,
        cfg["datasets"][ds_name]["size"],
    )
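
# A sketch of the config layout this helper expects, inferred from the key
# lookups above; the dataset name, paths, and sizes are illustrative:
#
#     cfg = {
#         "d": 768,
#         "normalise": False,
#         "datasets": {
#             "my_ds": {
#                 "root": "/path/to/vectors",
#                 "size": 2_000_000,
#                 "files": [
#                     {"name": "shard0.npy", "format": "npy",
#                      "dtype": "float32", "size": 1_000_000},
#                     {"name": "shard1.npy", "format": "npy",
#                      "dtype": "float32", "size": 1_000_000},
#                 ],
#             }
#         },
#     }
#
#     ds = create_dataset_from_oivf_config(cfg, "my_ds")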


@lru_cache(maxsize=100)
def _memmap_vecs(
    file_name: str, format: str, dtype: np.dtype, size: int, d: int
) -> np.ndarray:
    """
    If the file is in raw format, the file size must be divisible by
    the dimensionality and by the size of the data type. Otherwise, the
    file contains a header and we assume it is of .npy type. Returns the
    memmapped file in either case.
    """
    assert os.path.exists(file_name), f"file does not exist: {file_name}"
    if format == "raw":
        fl = os.path.getsize(file_name)
        nb = fl // d // dtype.itemsize
        assert nb == size, f"{nb} is different than config's {size}"
        assert fl == d * dtype.itemsize * nb  # no header
        return np.memmap(file_name, shape=(nb, d), dtype=dtype, mode="r")
    elif format == "npy":
        vecs = np.load(file_name, mmap_mode="r")
        assert vecs.shape[0] == size, f"size: {size}, shape: {vecs.shape[0]}"
        assert vecs.shape[1] == d
        assert vecs.dtype == dtype
        return vecs
    else:
        raise ValueError("The file cannot be loaded in the current format.")
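
# For example (illustrative numbers): 1_000_000 float32 vectors with d=256
# stored in "raw" format occupy exactly 1_000_000 * 256 * 4 = 1_024_000_000
# bytes; the asserts above reject any file whose size breaks this equality,
# since that would indicate a header or a dtype/dimension mismatch.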


class FileDescriptor:
    """Metadata for one vector file: file name, storage format ("raw" or
    "npy"), numpy dtype, and number of vectors it holds."""

    def __init__(self, name: str, format: str, dtype: np.dtype, size: int):
        self.name = name
        self.format = format
        self.dtype = dtype
        self.size = size


class MultiFileVectorDataset:
    """A set of vectors of dimension `d` spread across several memmapped
    files under `root`, exposed as one contiguous dataset of `size` rows."""

    def __init__(
        self,
        root: str,
        file_descriptors: List[FileDescriptor],
        d: int,
        normalize: bool,
        size: int,
    ):
        assert os.path.exists(root)
        self.root = root
        self.file_descriptors = file_descriptors
        self.d = d
        self.normalize = normalize
        self.size = size
        # Prefix sums of the per-file sizes, used to map a global row id
        # to a (file, local row) pair in get().
        self.file_offsets = [0]
        t = 0
        for f in self.file_descriptors:
            xb = _memmap_vecs(
                f"{self.root}/{f.name}", f.format, f.dtype, f.size, self.d
            )
            t += xb.shape[0]
            self.file_offsets.append(t)
        assert (
            t == self.size
        ), "the sum of num of embeddings per file != total num of embeddings"

    def iterate(self, start: int, batch_size: int, dt: np.dtype):
        """Yields batches of `batch_size` vectors (dtype `dt`), beginning
        at global row `start`; the last batch may be smaller."""
        buffer = np.empty(shape=(batch_size, self.d), dtype=dt)
        rem = 0
        for f in self.file_descriptors:
            if start >= f.size:
                # Skip whole files that precede the requested start row.
                start -= f.size
                continue
            logging.info(f"processing: {f.name}...")
            xb = _memmap_vecs(
                f"{self.root}/{f.name}",
                f.format,
                f.dtype,
                f.size,
                self.d,
            )
            if start > 0:
                xb = xb[start:]
                start = 0
            # Top up the carry-over buffer with the head of this file.
            req = min(batch_size - rem, xb.shape[0])
            buffer[rem:rem + req] = xb[:req]
            rem += req
            if rem == batch_size:
                if self.normalize:
                    faiss.normalize_L2(buffer)
                yield buffer.copy()
                rem = 0
            # Stream the rest of the file in full batches; a trailing
            # partial batch is carried over in `buffer` to the next file.
            for i in range(req, xb.shape[0], batch_size):
                j = i + batch_size
                if j <= xb.shape[0]:
                    tmp = xb[i:j].astype(dt)
                    if self.normalize:
                        faiss.normalize_L2(tmp)
                    yield tmp
                else:
                    rem = xb.shape[0] - i
                    buffer[:rem] = xb[i:j]
        if rem > 0:
            # Flush the last, possibly partial, batch.
            tmp = buffer[:rem]
            if self.normalize:
                faiss.normalize_L2(tmp)
            yield tmp
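
    # A typical use (a sketch; the batch size and the index are
    # illustrative):
    #
    #     for batch in ds.iterate(0, 65536, np.float32):
    #         index.add(batch)  # e.g. feed each batch to a faiss index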

    def get(self, idx: List[int]):
        """Returns the vectors at the global row ids `idx` as a float32
        array of shape (len(idx), d)."""
        n = len(idx)
        # Locate, for each global id, the file that contains it.
        fidx = np.searchsorted(self.file_offsets, idx, "right")
        res = np.empty(shape=(len(idx), self.d), dtype=np.float32)
        for r, id, fid in zip(range(n), idx, fidx):
            assert fid > 0 and fid <= len(self.file_descriptors), f"{fid}"
            f = self.file_descriptors[fid - 1]
            # deferring normalization until after reading the vec
            vecs = _memmap_vecs(
                f"{self.root}/{f.name}", f.format, f.dtype, f.size, self.d
            )
            i = id - self.file_offsets[fid - 1]
            assert i >= 0 and i < vecs.shape[0]
            res[r, :] = vecs[i]  # TODO: find a faster way
        if self.normalize:
            faiss.normalize_L2(res)
        return res
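
    # Worked example of the id-to-file mapping above (illustrative sizes):
    # with file_offsets == [0, 1000, 3000], a global id of 1500 gives
    # np.searchsorted(file_offsets, 1500, "right") == 2, i.e. the second
    # file (fid - 1 == 1) at local row 1500 - 1000 == 500.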

    def sample(self, n, idx_fn, vecs_fn):
        """Returns `n` vectors sampled uniformly without replacement,
        optionally caching the chosen row ids in `idx_fn` and the vectors
        in `vecs_fn` so that repeated calls reuse the same sample."""
        if vecs_fn and os.path.exists(vecs_fn):
            vecs = np.load(vecs_fn)
            assert vecs.shape == (n, self.d)
            return vecs
        if idx_fn and os.path.exists(idx_fn):
            idx = np.load(idx_fn)
            assert idx.size == n
        else:
            idx = np.array(sorted(random.sample(range(self.size), n)))
            if idx_fn:
                np.save(idx_fn, idx)
        vecs = self.get(idx)
        if vecs_fn:
            np.save(vecs_fn, vecs)
        return vecs
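
    # E.g. (illustrative file names): ds.sample(10_000, "idx.npy", "vecs.npy")
    # draws 10k rows once and returns the cached vectors on later calls.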

    def get_first_n(self, n, dt):
        """Returns the first `n` vectors of the dataset as a single batch
        of dtype `dt`."""
        assert n <= self.size
        return next(self.iterate(0, n, dt))
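

if __name__ == "__main__":
    # A minimal, self-contained demo (an illustrative sketch, not part of
    # the library): write a small raw file, wrap it in a dataset, and
    # iterate over it in batches.
    import tempfile

    d, n = 8, 100
    tmp_dir = tempfile.mkdtemp()
    rng = np.random.default_rng(0)
    rng.random((n, d), dtype=np.float32).tofile(f"{tmp_dir}/demo.raw")

    ds = MultiFileVectorDataset(
        tmp_dir,
        [FileDescriptor("demo.raw", "raw", np.dtype(np.float32), n)],
        d,
        normalize=False,
        size=n,
    )
    for batch in ds.iterate(0, 32, np.float32):
        print(batch.shape)  # (32, 8) for full batches, (4, 8) for the last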