/** * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include #include #include #include #include #include namespace faiss { namespace gpu { template void bfKnnConvert(GpuResources* resources, const GpuDistanceParams& args) { auto device = getCurrentDevice(); auto stream = resources->getDefaultStreamCurrentDevice(); auto& mem = resources->getMemoryManagerCurrentDevice(); auto tVectors = toDevice(resources, device, const_cast(reinterpret_cast(args.vectors)), stream, {args.vectorsRowMajor ? args.numVectors : args.dims, args.vectorsRowMajor ? args.dims : args.numVectors}); auto tQueries = toDevice(resources, device, const_cast(reinterpret_cast(args.queries)), stream, {args.queriesRowMajor ? args.numQueries : args.dims, args.queriesRowMajor ? args.dims : args.numQueries}); DeviceTensor tVectorNorms; if (args.vectorNorms) { tVectorNorms = toDevice(resources, device, const_cast(args.vectorNorms), stream, {args.numVectors}); } auto tOutDistances = toDevice(resources, device, args.outDistances, stream, {args.numQueries, args.k}); // The brute-force API only supports an interface for integer indices DeviceTensor tOutIntIndices(mem, {args.numQueries, args.k}, stream); // Since we've guaranteed that all arguments are on device, call the // implementation bfKnnOnDevice(resources, device, stream, tVectors, args.vectorsRowMajor, args.vectorNorms ? &tVectorNorms : nullptr, tQueries, args.queriesRowMajor, args.k, args.metric, args.metricArg, tOutDistances, tOutIntIndices, args.ignoreOutDistances); // Convert and copy int indices out auto tOutIndices = toDevice(resources, device, args.outIndices, stream, {args.numQueries, args.k}); // Convert int to idx_t convertTensor(stream, tOutIntIndices, tOutIndices); // Copy back if necessary fromDevice(tOutDistances, args.outDistances, stream); fromDevice(tOutIndices, args.outIndices, stream); } void bfKnn(GpuResources* resources, const GpuDistanceParams& args) { // For now, both vectors and queries must be of the same data type FAISS_THROW_IF_NOT_MSG( args.vectorType == args.queryType, "limitation: both vectorType and queryType must currently " "be the same (F32 or F16"); if (args.vectorType == DistanceDataType::F32) { bfKnnConvert(resources, args); } else if (args.vectorType == DistanceDataType::F16) { bfKnnConvert(resources, args); } else { FAISS_THROW_MSG("unknown vectorType"); } } // legacy version void bruteForceKnn(GpuResources* resources, faiss::MetricType metric, // A region of memory size numVectors x dims, with dims // innermost const float* vectors, bool vectorsRowMajor, int numVectors, // A region of memory size numQueries x dims, with dims // innermost const float* queries, bool queriesRowMajor, int numQueries, int dims, int k, // A region of memory size numQueries x k, with k // innermost float* outDistances, // A region of memory size numQueries x k, with k // innermost faiss::Index::idx_t* outIndices) { std::cerr << "bruteForceKnn is deprecated; call bfKnn instead" << std::endl; GpuDistanceParams args; args.metric = metric; args.k = k; args.dims = dims; args.vectors = vectors; args.vectorsRowMajor = vectorsRowMajor; args.numVectors = numVectors; args.queries = queries; args.queriesRowMajor = queriesRowMajor; args.numQueries = numQueries; args.outDistances = outDistances; args.outIndices = outIndices; bfKnn(resources, args); } } } // namespace