faiss/benchs/distributed_ondisk/distributed_kmeans.py
Matthijs Douze b8fe92dfee contrib clustering module (#2217)
Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2217

This diff introduces a new Faiss contrib module that contains:
- generic k-means implemented in Python (was in distributed_ondisk)
- the two-level clustering code, including a simple function that runs it on a Faiss IVF index
- sparse clustering code (new)

The main idea is that this code is often re-used, so it is better to have it in contrib.

Reviewed By: beauby

Differential Revision: D34170932

fbshipit-source-id: cc297cc56d241b5ef421500ed410d8e2be0f1b77
2022-02-28 14:18:47 -08:00


#! /usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Simple distributed kmeans implementation Relies on an abstraction
for the training matrix, that can be sharded over several machines.
"""
import os
import sys
import argparse
import numpy as np
import faiss
from multiprocessing.dummy import Pool as ThreadPool
from faiss.contrib import rpc
from faiss.contrib.datasets import SyntheticDataset
from faiss.contrib.vecs_io import bvecs_mmap, fvecs_mmap
from faiss.contrib.clustering import DatasetAssign, DatasetAssignGPU, kmeans


class DatasetAssignDispatch:
"""dispatches to several other DatasetAssigns and combines the
results"""
def __init__(self, xes, in_parallel):
self.xes = xes
self.d = xes[0].dim()
if not in_parallel:
self.imap = map
else:
self.pool = ThreadPool(len(self.xes))
self.imap = self.pool.imap
self.sizes = list(map(lambda x: x.count(), self.xes))
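        # prefix sums of the shard sizes: cs[i] is the first global index
        # of shard i, cs[-1] the total number of vectors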
self.cs = np.cumsum([0] + self.sizes)
def count(self):
return self.cs[-1]
def dim(self):
return self.d
def get_subset(self, indices):
res = np.zeros((len(indices), self.d), dtype='float32')
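        # find the shard each index falls into: cs[1:] holds the exclusive
        # end offset of each shard, so searchsorted returns the shard number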
nos = np.searchsorted(self.cs[1:], indices, side='right')
def handle(i):
mask = nos == i
sub_indices = indices[mask] - self.cs[i]
subset = self.xes[i].get_subset(sub_indices)
res[mask] = subset
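        # consume the iterator to force execution (imap is lazy)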
list(self.imap(handle, range(len(self.xes))))
return res
def assign_to(self, centroids, weights=None):
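        # query every shard (possibly in parallel), then concatenate the
        # assignments and accumulate the per-centroid sums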
src = self.imap(
lambda x: x.assign_to(centroids, weights),
self.xes
)
I = []
D = []
sum_per_centroid = None
for Ii, Di, sum_per_centroid_i in src:
I.append(Ii)
D.append(Di)
if sum_per_centroid is None:
sum_per_centroid = sum_per_centroid_i
else:
sum_per_centroid += sum_per_centroid_i
return np.hstack(I), np.hstack(D), sum_per_centroid


class AssignServer(rpc.Server):
    """ wraps a DatasetAssign object so that it can be exposed via RPC """
def __init__(self, s, assign, log_prefix=''):
rpc.Server.__init__(self, s, log_prefix=log_prefix)
self.assign = assign
def __getattr__(self, f):
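        # forward all other method calls (count, dim, get_subset,
        # assign_to) to the wrapped DatasetAssign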
return getattr(self.assign, f)


def do_test(todo):
testdata = '/datasets01_101/simsearch/041218/bigann/bigann_learn.bvecs'
if os.path.exists(testdata):
x = bvecs_mmap(testdata)
else:
print("using synthetic dataset")
ds = SyntheticDataset(128, 100000, 0, 0)
x = ds.get_train()
# bad distribution to stress-test split code
xx = x[:100000].copy()
xx[:50000] = x[0]
if "0" in todo:
# reference C++ run
km = faiss.Kmeans(x.shape[1], 1000, niter=20, verbose=True)
km.train(xx.astype('float32'))
if "1" in todo:
# using the Faiss c++ implementation
data = DatasetAssign(xx)
kmeans(1000, data, 20)
if "2" in todo:
# use the dispatch object (on local datasets)
data = DatasetAssignDispatch([
DatasetAssign(xx[20000 * i : 20000 * (i + 1)])
for i in range(5)
], False
)
kmeans(1000, data, 20)
if "3" in todo:
# same, with GPU
ngpu = faiss.get_num_gpus()
print('using %d GPUs' % ngpu)
data = DatasetAssignDispatch([
DatasetAssignGPU(xx[100000 * i // ngpu: 100000 * (i + 1) // ngpu], i)
for i in range(ngpu)
], True
)
kmeans(1000, data, 20)


def main():
parser = argparse.ArgumentParser()
def aa(*args, **kwargs):
group.add_argument(*args, **kwargs)
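    # note: aa() looks up `group` at call time, so each call adds to the
    # most recently created argument group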
group = parser.add_argument_group('general options')
aa('--test', default='', help='perform tests (comma-separated numbers)')
aa('--k', default=0, type=int, help='nb centroids')
aa('--seed', default=1234, type=int, help='random seed')
aa('--niter', default=20, type=int, help='nb iterations')
    aa('--gpu', default=-2, type=int, help='GPU to use (-2: none, -1: all)')
group = parser.add_argument_group('I/O options')
    aa('--indata', default='',
       help='data file to load (supported formats: fvecs, bvecs, npy)')
    aa('--i0', default=0, type=int, help='first vector to keep')
    aa('--i1', default=-1, type=int, help='last vector to keep + 1')
aa('--out', default='', help='file to store centroids')
aa('--store_each_iteration', default=False, action='store_true',
help='store centroid checkpoints')
group = parser.add_argument_group('server options')
aa('--server', action='store_true', default=False, help='run server')
aa('--port', default=12345, type=int, help='server port')
aa('--when_ready', default=None, help='store host:port to this file when ready')
aa('--ipv4', default=False, action='store_true', help='force ipv4')
group = parser.add_argument_group('client options')
aa('--client', action='store_true', default=False, help='run client')
aa('--servers', default='', help='list of server:port separated by spaces')
args = parser.parse_args()
if args.test:
do_test(args.test.split(','))
return
# prepare data matrix (either local or remote)
if args.indata:
print('loading ', args.indata)
if args.indata.endswith('.bvecs'):
x = bvecs_mmap(args.indata)
elif args.indata.endswith('.fvecs'):
x = fvecs_mmap(args.indata)
elif args.indata.endswith('.npy'):
x = np.load(args.indata, mmap_mode='r')
else:
            raise AssertionError("unsupported file extension: " + args.indata)
if args.i1 == -1:
args.i1 = len(x)
x = x[args.i0:args.i1]
if args.gpu == -2:
data = DatasetAssign(x)
else:
print('moving to GPU')
data = DatasetAssignGPU(x, args.gpu)
elif args.client:
print('connecting to servers')
def connect_client(hostport):
host, port = hostport.split(':')
port = int(port)
print('connecting %s:%d' % (host, port))
client = rpc.Client(host, port, v6=not args.ipv4)
print('client %s:%d ready' % (host, port))
return client
hostports = args.servers.strip().split(' ')
data = DatasetAssignDispatch(
list(map(connect_client, hostports)),
True
)
else:
        raise AssertionError("provide a data source with --indata or --client")
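
    # in server mode, expose the prepared dataset over RPC; otherwise
    # run the k-means driver over the (local or dispatched) dataset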
if args.server:
print('starting server')
log_prefix = f"{rpc.socket.gethostname()}:{args.port}"
rpc.run_server(
lambda s: AssignServer(s, data, log_prefix=log_prefix),
args.port, report_to_file=args.when_ready,
v6=not args.ipv4)
else:
print('running kmeans')
centroids = kmeans(args.k, data, niter=args.niter, seed=args.seed,
checkpoint=args.out if args.store_each_iteration else None)
if args.out != '':
print('writing centroids to', args.out)
np.save(args.out, centroids)


if __name__ == '__main__':
main()