/** * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ // -*- c++ -*- /* * implementation of Hyper-parameter auto-tuning */ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include namespace faiss { /*************************************************************** * index_factory ***************************************************************/ namespace { struct VTChain { std::vector chain; ~VTChain () { for (int i = 0; i < chain.size(); i++) { delete chain[i]; } } }; /// what kind of training does this coarse quantizer require? char get_trains_alone(const Index *coarse_quantizer) { return dynamic_cast(coarse_quantizer) ? 1 : dynamic_cast(coarse_quantizer) ? 2 : 0; } } Index *index_factory (int d, const char *description_in, MetricType metric) { FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT); VTChain vts; Index *coarse_quantizer = nullptr; Index *index = nullptr; bool add_idmap = false; bool make_IndexRefineFlat = false; ScopeDeleter1 del_coarse_quantizer, del_index; char description[strlen(description_in) + 1]; char *ptr; memcpy (description, description_in, strlen(description_in) + 1); int64_t ncentroids = -1; bool use_2layer = false; for (char *tok = strtok_r (description, " ,", &ptr); tok; tok = strtok_r (nullptr, " ,", &ptr)) { int d_out, opq_M, nbit, M, M2, pq_m, ncent, r2; std::string stok(tok); nbit = 8; // to avoid mem leaks with exceptions: // do all tests before any instanciation VectorTransform *vt_1 = nullptr; Index *coarse_quantizer_1 = nullptr; Index *index_1 = nullptr; // VectorTransforms if (sscanf (tok, "PCA%d", &d_out) == 1) { vt_1 = new PCAMatrix (d, d_out); d = d_out; } else if (sscanf (tok, "PCAR%d", &d_out) == 1) { vt_1 = new PCAMatrix (d, d_out, 0, true); d = d_out; } else if (sscanf (tok, "RR%d", &d_out) == 1) { vt_1 = new RandomRotationMatrix (d, d_out); d = d_out; } else if (sscanf (tok, "PCAW%d", &d_out) == 1) { vt_1 = new PCAMatrix (d, d_out, -0.5, false); d = d_out; } else if (sscanf (tok, "PCAWR%d", &d_out) == 1) { vt_1 = new PCAMatrix (d, d_out, -0.5, true); d = d_out; } else if (sscanf (tok, "OPQ%d_%d", &opq_M, &d_out) == 2) { vt_1 = new OPQMatrix (d, opq_M, d_out); d = d_out; } else if (sscanf (tok, "OPQ%d", &opq_M) == 1) { vt_1 = new OPQMatrix (d, opq_M); } else if (sscanf (tok, "ITQ%d", &d_out) == 1) { vt_1 = new ITQTransform (d, d_out, true); d = d_out; } else if (stok == "ITQ") { vt_1 = new ITQTransform (d, d, false); } else if (sscanf (tok, "Pad%d", &d_out) == 1) { if (d_out > d) { vt_1 = new RemapDimensionsTransform (d, d_out, false); d = d_out; } } else if (stok == "L2norm") { vt_1 = new NormalizationTransform (d, 2.0); // coarse quantizers } else if (!coarse_quantizer && sscanf (tok, "IVF%ld_HNSW%d", &ncentroids, &M) == 2) { FAISS_THROW_IF_NOT (metric == METRIC_L2); coarse_quantizer_1 = new IndexHNSWFlat (d, M); } else if (!coarse_quantizer && sscanf (tok, "IVF%ld", &ncentroids) == 1) { if (metric == METRIC_L2) { coarse_quantizer_1 = new IndexFlatL2 (d); } else { coarse_quantizer_1 = new IndexFlatIP (d); } } else if (!coarse_quantizer && sscanf (tok, "IMI2x%d", &nbit) == 1) { FAISS_THROW_IF_NOT_MSG (metric == METRIC_L2, "MultiIndex not implemented for inner prod search"); coarse_quantizer_1 = new MultiIndexQuantizer (d, 2, nbit); ncentroids = 1 << (2 * nbit); } else if (!coarse_quantizer && sscanf (tok, "Residual%dx%d", &M, &nbit) == 2) { FAISS_THROW_IF_NOT_MSG (metric == METRIC_L2, "MultiIndex not implemented for inner prod search"); coarse_quantizer_1 = new MultiIndexQuantizer (d, M, nbit); ncentroids = int64_t(1) << (M * nbit); use_2layer = true; } else if (!coarse_quantizer && sscanf (tok, "Residual%ld", &ncentroids) == 1) { coarse_quantizer_1 = new IndexFlatL2 (d); use_2layer = true; } else if (stok == "IDMap") { add_idmap = true; // IVFs } else if (!index && (stok == "Flat" || stok == "FlatDedup")) { if (coarse_quantizer) { // if there was an IVF in front, then it is an IVFFlat IndexIVF *index_ivf = stok == "Flat" ? new IndexIVFFlat ( coarse_quantizer, d, ncentroids, metric) : new IndexIVFFlatDedup ( coarse_quantizer, d, ncentroids, metric); index_ivf->quantizer_trains_alone = get_trains_alone (coarse_quantizer); index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT; del_coarse_quantizer.release (); index_ivf->own_fields = true; index_1 = index_ivf; } else { FAISS_THROW_IF_NOT_MSG (stok != "FlatDedup", "dedup supported only for IVFFlat"); index_1 = new IndexFlat (d, metric); } } else if (!index && (stok == "SQ8" || stok == "SQ4" || stok == "SQ6" || stok == "SQfp16")) { ScalarQuantizer::QuantizerType qt = stok == "SQ8" ? ScalarQuantizer::QT_8bit : stok == "SQ6" ? ScalarQuantizer::QT_6bit : stok == "SQ4" ? ScalarQuantizer::QT_4bit : stok == "SQfp16" ? ScalarQuantizer::QT_fp16 : ScalarQuantizer::QT_4bit; if (coarse_quantizer) { FAISS_THROW_IF_NOT (!use_2layer); IndexIVFScalarQuantizer *index_ivf = new IndexIVFScalarQuantizer ( coarse_quantizer, d, ncentroids, qt, metric); index_ivf->quantizer_trains_alone = get_trains_alone (coarse_quantizer); del_coarse_quantizer.release (); index_ivf->own_fields = true; index_1 = index_ivf; } else { index_1 = new IndexScalarQuantizer (d, qt, metric); } } else if (!index && sscanf (tok, "PQ%d+%d", &M, &M2) == 2) { FAISS_THROW_IF_NOT_MSG(coarse_quantizer, "PQ with + works only with an IVF"); FAISS_THROW_IF_NOT_MSG(metric == METRIC_L2, "IVFPQR not implemented for inner product search"); IndexIVFPQR *index_ivf = new IndexIVFPQR ( coarse_quantizer, d, ncentroids, M, 8, M2, 8); index_ivf->quantizer_trains_alone = get_trains_alone (coarse_quantizer); del_coarse_quantizer.release (); index_ivf->own_fields = true; index_1 = index_ivf; } else if (!index && (sscanf (tok, "PQ%dx%d", &M, &nbit) == 2 || sscanf (tok, "PQ%d", &M) == 1 || sscanf (tok, "PQ%dnp", &M) == 1)) { bool do_polysemous_training = stok.find("np") == std::string::npos; if (coarse_quantizer) { if (!use_2layer) { IndexIVFPQ *index_ivf = new IndexIVFPQ ( coarse_quantizer, d, ncentroids, M, nbit); index_ivf->quantizer_trains_alone = get_trains_alone (coarse_quantizer); index_ivf->metric_type = metric; index_ivf->cp.spherical = metric == METRIC_INNER_PRODUCT; del_coarse_quantizer.release (); index_ivf->own_fields = true; index_ivf->do_polysemous_training = do_polysemous_training; index_1 = index_ivf; } else { Index2Layer *index_2l = new Index2Layer (coarse_quantizer, ncentroids, M, nbit); index_2l->q1.quantizer_trains_alone = get_trains_alone (coarse_quantizer); index_2l->q1.own_fields = true; index_1 = index_2l; } } else { IndexPQ *index_pq = new IndexPQ (d, M, nbit, metric); index_pq->do_polysemous_training = do_polysemous_training; index_1 = index_pq; } } else if (!index && sscanf (tok, "HNSW%d_%d+PQ%d", &M, &ncent, &pq_m) == 3) { Index * quant = new IndexFlatL2 (d); IndexHNSW2Level * hidx2l = new IndexHNSW2Level (quant, ncent, pq_m, M); Index2Layer * idx2l = dynamic_cast(hidx2l->storage); idx2l->q1.own_fields = true; index_1 = hidx2l; } else if (!index && sscanf (tok, "HNSW%d_2x%d+PQ%d", &M, &nbit, &pq_m) == 3) { Index * quant = new MultiIndexQuantizer (d, 2, nbit); IndexHNSW2Level * hidx2l = new IndexHNSW2Level (quant, 1 << (2 * nbit), pq_m, M); Index2Layer * idx2l = dynamic_cast(hidx2l->storage); idx2l->q1.own_fields = true; idx2l->q1.quantizer_trains_alone = 1; index_1 = hidx2l; } else if (!index && sscanf (tok, "HNSW%d_PQ%d", &M, &pq_m) == 2) { index_1 = new IndexHNSWPQ (d, pq_m, M); } else if (!index && sscanf (tok, "HNSW%d", &M) == 1) { index_1 = new IndexHNSWFlat (d, M); } else if (!index && sscanf (tok, "HNSW%d_SQ%d", &M, &pq_m) == 2 && pq_m == 8) { index_1 = new IndexHNSWSQ (d, ScalarQuantizer::QT_8bit, M); } else if (!index && (stok == "LSH" || stok == "LSHr" || stok == "LSHrt" || stok == "LSHt")) { bool rotate_data = strstr(tok, "r") != nullptr; bool train_thresholds = strstr(tok, "t") != nullptr; index_1 = new IndexLSH (d, d, rotate_data, train_thresholds); } else if (!index && sscanf (tok, "ZnLattice%dx%d_%d", &M, &r2, &nbit) == 3) { FAISS_THROW_IF_NOT(!coarse_quantizer); index_1 = new IndexLattice(d, M, nbit, r2); } else if (stok == "RFlat") { make_IndexRefineFlat = true; } else { FAISS_THROW_FMT( "could not parse token \"%s\" in %s\n", tok, description_in); } if (index_1 && add_idmap) { IndexIDMap *idmap = new IndexIDMap(index_1); del_index.set (idmap); idmap->own_fields = true; index_1 = idmap; add_idmap = false; } if (vt_1) { vts.chain.push_back (vt_1); } if (coarse_quantizer_1) { coarse_quantizer = coarse_quantizer_1; del_coarse_quantizer.set (coarse_quantizer); } if (index_1) { index = index_1; del_index.set (index); } } FAISS_THROW_IF_NOT_FMT(index, "description %s did not generate an index", description_in); // nothing can go wrong now del_index.release (); del_coarse_quantizer.release (); if (add_idmap) { fprintf(stderr, "index_factory: WARNING: " "IDMap option not used\n"); } if (vts.chain.size() > 0) { IndexPreTransform *index_pt = new IndexPreTransform (index); index_pt->own_fields = true; // add from back while (vts.chain.size() > 0) { index_pt->prepend_transform (vts.chain.back ()); vts.chain.pop_back (); } index = index_pt; } if (make_IndexRefineFlat) { IndexRefineFlat *index_rf = new IndexRefineFlat (index); index_rf->own_fields = true; index = index_rf; } return index; } IndexBinary *index_binary_factory(int d, const char *description) { IndexBinary *index = nullptr; int ncentroids = -1; int M; if (sscanf(description, "BIVF%d_HNSW%d", &ncentroids, &M) == 2) { IndexBinaryIVF *index_ivf = new IndexBinaryIVF( new IndexBinaryHNSW(d, M), d, ncentroids ); index_ivf->own_fields = true; index = index_ivf; } else if (sscanf(description, "BIVF%d", &ncentroids) == 1) { IndexBinaryIVF *index_ivf = new IndexBinaryIVF( new IndexBinaryFlat(d), d, ncentroids ); index_ivf->own_fields = true; index = index_ivf; } else if (sscanf(description, "BHNSW%d", &M) == 1) { IndexBinaryHNSW *index_hnsw = new IndexBinaryHNSW(d, M); index = index_hnsw; } else if (std::string(description) == "BFlat") { index = new IndexBinaryFlat(d); } else { FAISS_THROW_IF_NOT_FMT(index, "description %s did not generate an index", description); } return index; } } // namespace faiss