faiss/demos/demo_distributed_kmeans_tor...

174 lines
5.3 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
import torch.distributed
import faiss
import faiss.contrib.torch_utils
from faiss.contrib.torch import clustering
from faiss.contrib import datasets
class DatasetAssignDistributedGPU(clustering.DatasetAssign):
    """
    Dataset wrapper used to run k-means on a dataset that is sharded over
    several GPU workers.

    There is one instance per worker and each worker holds one shard of the
    dataset. Only the master worker (rank 0) runs through the k-means
    function. The other workers have to call the same methods, in the same
    order, passing None as argument, so that the collective operations of all
    workers stay in sync.
    """

    def __init__(self, res, x, rank, nproc):
        """
        res: GPU resources object, handed over to faiss.knn_gpu
        x: torch tensor with the local shard, one vector per row. It is
           passed to the base class constructor (the methods below read it
           back as self.x).
        rank: index of this worker, rank 0 is the master
        nproc: total number of workers

        All workers must build their instance together: the shard sizes are
        exchanged with a collective all_gather.
        """
        clustering.DatasetAssign.__init__(self, x)
        self.res = res
        self.rank = rank
        self.nproc = nproc
        self.device = x.device

        # each worker fills in its own slot, the all_gather then gives every
        # worker the size of every shard
        n = len(x)
        sizes = torch.zeros(nproc, device=self.device, dtype=torch.int64)
        sizes[rank] = n
        torch.distributed.all_gather(
            [sizes[i:i + 1] for i in range(nproc)], sizes[rank:rank + 1])
        self.sizes = sizes.cpu().numpy()

        # begin & end of each shard in the global numbering of the vectors:
        # shard r covers [cs[r], cs[r + 1])
        self.cs = np.zeros(nproc + 1, dtype='int64')
        self.cs[1:] = np.cumsum(self.sizes)

    def count(self):
        " total number of vectors, summed over all shards "
        return int(self.sizes.sum())

    def int_to_slaves(self, i):
        """
        Broadcast an int from the master to all workers.

        The master passes the value in i, the other workers must pass None.
        Returns the master's value on every worker.
        """
        rank = self.rank
        tab = torch.zeros(1, device=self.device, dtype=torch.int64)
        if rank == 0:
            tab[0] = i
        else:
            assert i is None
        torch.distributed.broadcast(tab, 0)
        return tab.item()

    def get_subset(self, indices):
        """
        Collect the vectors with the given global indices on the master.

        indices: on the master, a numpy array of indices in the global
           numbering (see self.cs); None on the other workers.

        Returns a tensor with one row per index on the master, None on the
        other workers.
        """
        rank = self.rank
        assert rank == 0 or indices is None

        # the workers first need the number of indices to size their buffer
        len_indices = self.int_to_slaves(len(indices) if rank == 0 else None)
        if rank == 0:
            indices = torch.from_numpy(indices).to(self.device)
        else:
            indices = torch.zeros(
                len_indices, dtype=torch.int64, device=self.device)
        torch.distributed.broadcast(indices, 0)

        # select the subset of indices that falls into the local shard
        i0, i1 = self.cs[rank], self.cs[rank + 1]
        mask = torch.logical_and(indices < i1, indices >= i0)

        # rows owned by other workers stay at 0, so that summing the outputs
        # of all workers yields the complete result
        output = torch.zeros(
            len_indices, self.x.shape[1],
            dtype=self.x.dtype, device=self.device)
        output[mask] = self.x[indices[mask] - i0]
        torch.distributed.reduce(output, 0)  # sum
        if rank == 0:
            return output
        else:
            return None

    def perform_search(self, centroids):
        """
        Not used by this class: assign_to() does the search itself.

        Always raises AssertionError. The exception is raised explicitly
        because an `assert False` statement is removed when python runs
        with -O.
        """
        raise AssertionError("should not be called")

    def assign_to(self, centroids, weights=None):
        """
        Assign the vectors of all shards to their nearest centroid.

        centroids: on the master, tensor with one centroid per row; None on
           the other workers.
        weights: not supported, must be None.

        On the master, returns (I, D, sum_per_centroid) where
        - I is a numpy array with, for each vector of the whole dataset (in
          shard order), the index of the centroid it is assigned to,
        - D is a tensor with the corresponding distances, as returned by
          faiss.knn_gpu,
        - sum_per_centroid is a tensor that contains, for each centroid, the
          sum of the vectors assigned to it.
        Returns None on the other workers.

        Raises AssertionError if weights are given or if the shards do not
        all have the same size.
        """
        if weights is not None:
            # applying weights to a single shard would silently give an
            # inconsistent result, so reject them on every code path
            raise AssertionError(
                "weights are not supported by the distributed assignment")
        rank, nproc = self.rank, self.nproc
        assert rank == 0 or centroids is None

        # gather does not support tensors of different sizes
        # (this should be implemented with point-to-point communication).
        # self.sizes is identical on all workers, so the check is done by
        # everyone before the first collective: the workers then fail
        # together instead of the master failing alone while the others
        # wait in gather.
        if not np.all(self.sizes == self.sizes[0]):
            raise AssertionError(
                "assign_to requires all shards to have the same size")

        # send the centroids to all workers
        nc = self.int_to_slaves(len(centroids) if rank == 0 else None)
        if rank != 0:
            centroids = torch.zeros(
                nc, self.x.shape[1], dtype=self.x.dtype, device=self.device)
        torch.distributed.broadcast(centroids, 0)

        # perform search: nearest centroid of each vector of the local shard
        D, I = faiss.knn_gpu(
            self.res, self.x, centroids, 1, device=self.device.index)
        I = I.ravel()
        D = D.ravel()

        # local per-centroid sums, then added up over the workers on master
        sum_per_centroid = torch.zeros_like(centroids)
        sum_per_centroid.index_add_(0, I, self.x)
        torch.distributed.reduce(sum_per_centroid, 0)

        if rank != 0:
            # contribute the local assignment to the master's gathers
            torch.distributed.gather(I, None, dst=0)
            torch.distributed.gather(D, None, dst=0)
            return None

        # master: each worker's result lands in the slice of its shard
        device = self.device
        ntotal = self.count()
        cs = self.cs
        all_I = torch.zeros(ntotal, dtype=I.dtype, device=device)
        all_D = torch.zeros(ntotal, dtype=D.dtype, device=device)
        torch.distributed.gather(
            I, [all_I[cs[r]:cs[r + 1]] for r in range(nproc)],
            dst=0,
        )
        torch.distributed.gather(
            D, [all_D[cs[r]:cs[r + 1]] for r in range(nproc)],
            dst=0,
        )
        return all_I.cpu().numpy(), all_D, sum_per_centroid
def main():
    """
    Demo: k-means on a dataset sharded over several GPU workers.

    Must be started once per worker by a launcher that provides what
    torch.distributed.init_process_group needs. Rank 0 runs the k-means
    function, the other ranks mirror its calls so that the collective
    operations stay aligned.
    """
    torch.distributed.init_process_group(
        backend="nccl",
    )
    rank = torch.distributed.get_rank()
    nproc = torch.distributed.get_world_size()

    # the current version only supports shards of the same size.
    # Each worker generates its own shard (the seed depends on the rank).
    ds = datasets.SyntheticDataset(32, 10000, 0, 0, seed=1234 + rank)
    x = ds.get_train()

    # the rank is global over all nodes whereas CUDA device ids are local to
    # a node: wrap around so that multi-node runs pick a valid device.
    # This is the same device as before when rank < number of local GPUs.
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    torch.cuda.set_device(device)
    x = torch.from_numpy(x).to(device)

    res = faiss.StandardGpuResources()
    da = DatasetAssignDistributedGPU(res, x, rank, nproc)

    k = 1000
    niter = 25

    if rank == 0:
        print(f"sizes = {da.sizes}")
        centroids, iteration_stats = clustering.kmeans(
            k, da, niter=niter, return_stats=True)
        print("clusters:", centroids.cpu().numpy())
    else:
        # make sure the iterations are aligned with master: one get_subset
        # call, then one assign_to call per iteration
        da.get_subset(None)
        for _ in range(niter):
            da.assign_to(None)

    torch.distributed.barrier()
    print("Done")

    # release the communicator explicitly instead of relying on process exit
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    main()