// Mirror of https://github.com/PaddlePaddle/PaddleClas.git
// (synced 2025-06-03 21:55:06 +08:00; 65 lines, 2.1 KiB, C++)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <faiss/index_factory.h>
|
||
|
#include <faiss/index_io.h>
|
||
|
#include <fstream>
|
||
|
#include <regex>
|
||
|
#include <iostream>
|
||
|
#include <cstdio>
|
||
|
#include "include/vector_search.h"
|
||
|
|
||
|
void VectorSearch::LoadIndexFile(){
|
||
|
std::string file_path = this->index_dir + OS_PATH_SEP + "vector.index";
|
||
|
const char* fname = file_path.c_str();
|
||
|
this->index = faiss::read_index(fname, 0);
|
||
|
}
|
||
|
|
||
|
void VectorSearch::LoadIdMap(){
|
||
|
std::string file_path = this->index_dir + OS_PATH_SEP + "id_map.txt";
|
||
|
std::ifstream in(file_path);
|
||
|
std::string line;
|
||
|
std::vector<std::string> m_vec;
|
||
|
if (in){
|
||
|
while (getline(in, line)){
|
||
|
std::regex ws_re("\\s+");
|
||
|
std::vector<std::string> v(
|
||
|
std::sregex_token_iterator(line.begin(), line.end(), ws_re, -1),
|
||
|
std::sregex_token_iterator());
|
||
|
if (v.size() !=2){
|
||
|
std::cout << "The number of element for each line in : " << file_path
|
||
|
<< "must be 2, exit the program..." << std::endl;
|
||
|
exit(1);
|
||
|
}else
|
||
|
this->id_map.insert(std::pair<long int, std::string>(std::stol(v[0], nullptr, 10), v[1]));
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Runs a top-`return_k` nearest-neighbour search for `query_number` queries.
//
// `feature` is forwarded unchanged to the index's search();
// presumably it holds query_number rows of the index dimension, laid out
// contiguously — TODO confirm against the caller.
// Results are written into the member buffers D (distances) and I (ids).
// NOTE(review): D and I must have room for query_number * return_k entries;
// their allocation is not visible in this file.
// real_query_number is only updated after search() returns, so a throwing
// search leaves the previous count in place.
void VectorSearch::Search(float *feature, int query_number) {
  auto &faiss_index = this->index;
  faiss_index->search(query_number, feature, this->return_k, this->D, this->I);
  this->real_query_number = query_number;
}
|
||
|
|
||
|
// Refreshes the cached `sr` member from the state left by the last Search()
// and returns a const reference to it.
//
// The returned reference aliases a member, so the next call to this method
// overwrites what the caller sees.
// NOTE(review): whether D and I are copied deeply or only as pointers
// depends on their declared types in the header — confirm before holding the
// result across another Search().
const SearchResult &VectorSearch::GetSearchResult() {
  auto &result = this->sr;
  result.query_number = this->real_query_number;
  result.return_k = this->return_k;
  result.D = this->D;
  result.I = this->I;
  return result;
}
|
||
|
|
||
|
// Returns the label associated with search-result id `ind`.
//
// The lookup uses id_map.at(), so an id with no entry throws
// std::out_of_range (assuming id_map is a standard map type — confirm).
// NOTE(review): faiss reports missing neighbours with id -1; callers should
// filter those out before calling this.
const std::string &VectorSearch::GetLabel(faiss::Index::idx_t ind) {
  const std::string &label = this->id_map.at(ind);
  return label;
}
|