# faiss/contrib/torch/clustering.py (Python, 60 lines, 1.6 KiB)

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
This contrib module contains PyTorch code for k-means clustering.
"""
import faiss
import faiss.contrib.torch_utils
import torch
# the kmeans can produce both torch and numpy centroids
from faiss.contrib.clustering import kmeans
class DatasetAssign:
    """Wrapper for a tensor that offers a function to assign the vectors
    to centroids. All other implementations offer the same interface.

    The wrapped tensor ``x`` is used as a 2D array: ``count()`` reads its
    first axis and ``dim()`` its second axis.
    """

    def __init__(self, x):
        self.x = x

    def count(self):
        """Return the number of vectors (size of the first axis of ``x``)."""
        return self.x.shape[0]

    def dim(self):
        """Return the vector dimension (size of the second axis of ``x``)."""
        return self.x.shape[1]

    def get_subset(self, indices):
        """Return the vectors selected by ``indices``, i.e. ``x[indices]``."""
        return self.x[indices]

    def perform_search(self, centroids):
        """Search the single nearest centroid of every stored vector.

        Returns the ``(D, I)`` pair produced by ``faiss.knn`` with k=1.
        ``assign_to`` goes through this method, so subclasses can override
        it to use a different search backend.
        """
        return faiss.knn(self.x, centroids, 1)

    def assign_to(self, centroids, weights=None):
        """Assign every stored vector to its nearest centroid.

        :param centroids: 2D tensor with one centroid per row
        :param weights: optional per-vector weights, one value per row of
            ``x``. Anything accepted by ``torch.as_tensor`` works (tensor,
            numpy array, list); it does not need to match the dtype of the
            data.
        :return: tuple ``(I, D, sum_per_centroid)`` where ``I`` is a CPU
            numpy array with the centroid index of each vector, ``D`` is
            the flattened first output of ``perform_search`` and
            ``sum_per_centroid`` is a tensor shaped like ``centroids``
            holding the (optionally weighted) sum of the vectors assigned
            to each centroid.
        """
        D, I = self.perform_search(centroids)
        # the search returns one column per neighbor; k=1 so flatten it
        I = I.ravel()
        D = D.ravel()

        sum_per_centroid = torch.zeros_like(centroids)
        if weights is None:
            contributions = self.x
        else:
            weights = torch.as_tensor(weights, device=self.x.device)
            # index_add_ requires the source to have the accumulator dtype;
            # the cast is a no-op when the product already matches.
            contributions = (self.x * weights[:, None]).to(
                sum_per_centroid.dtype)
        sum_per_centroid.index_add_(0, I, contributions)

        # the indices are handed back as a CPU numpy array, while D and the
        # per-centroid sums remain torch tensors.
        return I.cpu().numpy(), D, sum_per_centroid
class DatasetAssignGPU(DatasetAssign):
    """DatasetAssign variant whose nearest-centroid search is done by
    ``faiss.knn_gpu`` with the resources object given at construction."""

    def __init__(self, res, x):
        super().__init__(x)
        # kept so that perform_search can hand it to faiss.knn_gpu
        self.res = res

    def perform_search(self, centroids):
        """Return the k=1 ``faiss.knn_gpu`` result of searching the stored
        vectors against ``centroids``."""
        resources = self.res
        vectors = self.x
        return faiss.knn_gpu(resources, vectors, centroids, 1)