fix faiss bug
parent
d0c01a97e6
commit
4cd9dc0e05
|
@ -26,45 +26,45 @@
|
||||||
#include <map>
|
#include <map>
|
||||||
|
|
||||||
struct SearchResult {
|
struct SearchResult {
|
||||||
std::vector <faiss::Index::idx_t> I;
|
std::vector<faiss::idx_t> I;
|
||||||
std::vector<float> D;
|
std::vector<float> D;
|
||||||
int return_k;
|
int return_k;
|
||||||
};
|
};
|
||||||
|
|
||||||
class VectorSearch {
|
class VectorSearch {
|
||||||
public:
|
public:
|
||||||
explicit VectorSearch(const YAML::Node &config_file) {
|
explicit VectorSearch(const YAML::Node &config_file) {
|
||||||
// IndexProcess
|
// IndexProcess
|
||||||
this->index_dir =
|
this->index_dir =
|
||||||
config_file["IndexProcess"]["index_dir"].as<std::string>();
|
config_file["IndexProcess"]["index_dir"].as<std::string>();
|
||||||
this->return_k = config_file["IndexProcess"]["return_k"].as<int>();
|
this->return_k = config_file["IndexProcess"]["return_k"].as<int>();
|
||||||
this->score_thres = config_file["IndexProcess"]["score_thres"].as<float>();
|
this->score_thres = config_file["IndexProcess"]["score_thres"].as<float>();
|
||||||
this->max_query_number =
|
this->max_query_number =
|
||||||
config_file["Global"]["max_det_results"].as<int>() + 1;
|
config_file["Global"]["max_det_results"].as<int>() + 1;
|
||||||
LoadIdMap();
|
LoadIdMap();
|
||||||
LoadIndexFile();
|
LoadIndexFile();
|
||||||
this->I.resize(this->return_k * this->max_query_number);
|
this->I.resize(this->return_k * this->max_query_number);
|
||||||
this->D.resize(this->return_k * this->max_query_number);
|
this->D.resize(this->return_k * this->max_query_number);
|
||||||
};
|
};
|
||||||
|
|
||||||
void LoadIdMap();
|
void LoadIdMap();
|
||||||
|
|
||||||
void LoadIndexFile();
|
void LoadIndexFile();
|
||||||
|
|
||||||
const SearchResult &Search(float *feature, int query_number);
|
const SearchResult &Search(float *feature, int query_number);
|
||||||
|
|
||||||
const std::string &GetLabel(faiss::Index::idx_t ind);
|
const std::string &GetLabel(faiss::idx_t ind);
|
||||||
|
|
||||||
const float &GetThreshold() { return this->score_thres; }
|
const float &GetThreshold() { return this->score_thres; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::string index_dir;
|
std::string index_dir;
|
||||||
int return_k = 5;
|
int return_k = 5;
|
||||||
float score_thres = 0.5;
|
float score_thres = 0.5;
|
||||||
std::map<long int, std::string> id_map;
|
std::map<long int, std::string> id_map;
|
||||||
faiss::Index *index;
|
faiss::Index *index;
|
||||||
int max_query_number = 6;
|
int max_query_number = 6;
|
||||||
std::vector<float> D;
|
std::vector<float> D;
|
||||||
std::vector <faiss::Index::idx_t> I;
|
std::vector<faiss::idx_t> I;
|
||||||
SearchResult sr;
|
SearchResult sr;
|
||||||
};
|
};
|
||||||
|
|
|
@ -20,43 +20,43 @@
|
||||||
#include <regex>
|
#include <regex>
|
||||||
|
|
||||||
void VectorSearch::LoadIndexFile() {
|
void VectorSearch::LoadIndexFile() {
|
||||||
std::string file_path = this->index_dir + OS_PATH_SEP + "vector.index";
|
std::string file_path = this->index_dir + OS_PATH_SEP + "vector.index";
|
||||||
const char *fname = file_path.c_str();
|
const char *fname = file_path.c_str();
|
||||||
this->index = faiss::read_index(fname, 0);
|
this->index = faiss::read_index(fname, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
void VectorSearch::LoadIdMap() {
|
void VectorSearch::LoadIdMap() {
|
||||||
std::string file_path = this->index_dir + OS_PATH_SEP + "id_map.txt";
|
std::string file_path = this->index_dir + OS_PATH_SEP + "id_map.txt";
|
||||||
std::ifstream in(file_path);
|
std::ifstream in(file_path);
|
||||||
std::string line;
|
std::string line;
|
||||||
std::vector <std::string> m_vec;
|
std::vector<std::string> m_vec;
|
||||||
if (in) {
|
if (in) {
|
||||||
while (getline(in, line)) {
|
while (getline(in, line)) {
|
||||||
std::regex ws_re("\\s+");
|
std::regex ws_re("\\s+");
|
||||||
std::vector <std::string> v(
|
std::vector<std::string> v(
|
||||||
std::sregex_token_iterator(line.begin(), line.end(), ws_re, -1),
|
std::sregex_token_iterator(line.begin(), line.end(), ws_re, -1),
|
||||||
std::sregex_token_iterator());
|
std::sregex_token_iterator());
|
||||||
if (v.size() != 2) {
|
if (v.size() != 2) {
|
||||||
std::cout << "The number of element for each line in : " << file_path
|
std::cout << "The number of element for each line in : " << file_path
|
||||||
<< "must be 2, exit the program..." << std::endl;
|
<< "must be 2, exit the program..." << std::endl;
|
||||||
exit(1);
|
exit(1);
|
||||||
} else
|
} else
|
||||||
this->id_map.insert(std::pair<long int, std::string>(
|
this->id_map.insert(std::pair<long int, std::string>(
|
||||||
std::stol(v[0], nullptr, 10), v[1]));
|
std::stol(v[0], nullptr, 10), v[1]));
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Runs faiss::Index::search for `query_number` queries, passing `feature`
// straight through as the query buffer and asking for return_k neighbors per
// query. The outputs are copied into the member SearchResult, which is
// returned by reference and overwritten by every subsequent call.
const SearchResult &VectorSearch::Search(float *feature, int query_number) {
  // One slot per (query, neighbor) pair in both output buffers.
  const auto slots = return_k * query_number;
  D.resize(slots);
  I.resize(slots);

  index->search(query_number, feature, return_k, D.data(), I.data());

  sr.return_k = return_k;
  sr.D = D;
  sr.I = I;
  return sr;
}
|
||||||
|
|
||||||
// Looks up the label stored for id `ind` in id_map. Uses std::map::at, so an
// id that is not present raises std::out_of_range.
const std::string &VectorSearch::GetLabel(faiss::idx_t ind) {
  const std::string &label = id_map.at(ind);
  return label;
}
|
||||||
|
|
Loading…
Reference in New Issue