// Source: PaddleClas/deploy/cpp_shitu/src/vector_search.cpp
// (65 lines, 2.1 KiB, C++; snapshot dated 2021-11-08 15:26:13 +00:00)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <faiss/index_factory.h>
#include <faiss/index_io.h>
#include <fstream>
#include <regex>
#include <iostream>
#include <cstdio>
#include "include/vector_search.h"
// Loads the FAISS index stored at "<index_dir>/vector.index" and stores the
// pointer returned by faiss::read_index in this->index. The second argument
// (0) is the io_flags value forwarded to faiss::read_index.
// NOTE(review): the previous value of this->index is overwritten without being
// released; confirm ownership/cleanup in vector_search.h.
void VectorSearch::LoadIndexFile(){
  const std::string index_path = this->index_dir + OS_PATH_SEP + "vector.index";
  this->index = faiss::read_index(index_path.c_str(), 0);
}
// Loads the id -> label mapping from "<index_dir>/id_map.txt" into id_map.
// Each non-blank line must contain exactly two whitespace-separated fields:
// an integer id (parsed base 10 with std::stol) followed by a label.
// Leading/trailing whitespace on a line is ignored and blank lines are skipped.
// A line with any other number of fields prints an error and terminates the
// process with exit(1). If the file cannot be opened, a message is printed and
// id_map is left unchanged.
void VectorSearch::LoadIdMap(){
  std::string file_path = this->index_dir + OS_PATH_SEP + "id_map.txt";
  std::ifstream in(file_path);
  if (!in){
    // Previously this case was silently ignored, which only surfaced later as
    // an out_of_range from GetLabel; report it at the source instead.
    std::cout << "Can not open id map file: " << file_path
              << ", no labels were loaded." << std::endl;
    return;
  }
  // Compiling a std::regex is expensive; build it once, not once per line.
  const std::regex ws_re("\\s+");
  std::string line;
  while (std::getline(in, line)){
    std::vector<std::string> v;
    // Splitting with submatch -1 yields an empty leading token when the line
    // starts with whitespace (and a single empty token for an empty line);
    // drop empty tokens so only real fields are counted.
    for (std::sregex_token_iterator it(line.begin(), line.end(), ws_re, -1), end;
         it != end; ++it){
      if (it->length() > 0) v.push_back(it->str());
    }
    if (v.empty()) continue;  // blank or whitespace-only line
    if (v.size() != 2){
      std::cout << "The number of element for each line in : " << file_path
                << " must be 2, exit the program..." << std::endl;
      exit(1);
    }
    this->id_map.insert(
        std::pair<long int, std::string>(std::stol(v[0], nullptr, 10), v[1]));
  }
}
void VectorSearch::Search(float *feature, int query_number){
this->index->search(query_number, feature, return_k, D, I);
this->real_query_number = query_number;
}
// Fills the member SearchResult `sr` from the state of the most recent Search()
// call (query count, return_k, and the D/I result buffers) and returns a const
// reference to it. The reference stays tied to this object, so the next
// call overwrites the same fields.
const SearchResult& VectorSearch::GetSearchResult(){
  SearchResult &result = this->sr;
  result.D = D;
  result.I = I;
  result.return_k = return_k;
  result.query_number = real_query_number;
  return result;
}
// Returns the label that LoadIdMap associated with index id `ind`.
// Uses at(), so an id that is not in the map (for example -1, which FAISS
// reports when fewer than k neighbours exist) throws std::out_of_range.
const std::string& VectorSearch::GetLabel(faiss::Index::idx_t ind){
  const auto &labels = this->id_map;
  return labels.at(ind);
}