Support cosine distance for training vectors (#4227)

Summary:

same as title

Differential Revision: D70724590
export-D70724590
Junjie Qi 2025-03-06 16:42:20 -08:00 committed by Facebook GitHub Bot
parent c109174198
commit 4bf99c3171
2 changed files with 4 additions and 0 deletions

View File

@ -106,6 +106,8 @@ class DatasetDescriptor:
# desc_name
desc_name: Optional[str] = None
normalize_L2: bool = False
def __hash__(self):
return hash(self.get_filename())

View File

@ -1138,6 +1138,8 @@ class IndexFromFactory(Index):
return None, None, ""
logger.info(f"assemble, train {self.factory}")
xt = self.io.get_dataset(self.training_vectors)
if self.training_vectors.normalize_L2:
faiss.normalize_L2(xt)
_, t, _ = timer("train", lambda: codec.train(xt), once=True)
t_aggregate += t