# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

from PyQt5 import QtCore, QtGui, QtWidgets

import mod.mainwindow

from paddleclas.deploy.utils import config, logger
from paddleclas.deploy.python.predict_rec import RecPredictor
from fastapi import FastAPI
import uvicorn
import numpy as np
import faiss
from typing import List
import pickle
import cv2
import socket
import json
from multiprocessing import Process

"""
|
||
完整的index库如下:
|
||
root_path/ # 库存储目录
|
||
|-- image_list.txt # 图像列表,每行:image_path label。由前端生成及修改。后端只读
|
||
|-- features.pkl # 建库之后,保存的embedding向量,后端生成,前端无需操作
|
||
|-- images # 图像存储目录,由前端生成及增删查等操作。后端只读
|
||
| |-- md5.jpg
|
||
| |-- md5.jpg
|
||
| |-- ……
|
||
|-- index # 真正的生成的index库存储目录,后端生成及操作,前端无需操作。
|
||
| |-- vector.index # faiss生成的索引库
|
||
| |-- id_map.pkl # 索引文件
|
||
"""


class ShiTuIndexManager(object):
    def __init__(self, config):
        self.root_path = None
        self.image_list_path = "image_list.txt"
        self.image_dir = "images"
        self.index_path = "index/vector.index"
        self.id_map_path = "index/id_map.pkl"
        self.features_path = "features.pkl"
        self.index = None
        self.id_map = None
        self.features = None
        self.config = config
        self.predictor = RecPredictor(config)

    def _load_pickle(self, path):
        if os.path.exists(path):
            with open(path, 'rb') as fd:
                return pickle.load(fd)
        else:
            return None

    def _save_pickle(self, path, data):
        if not os.path.exists(os.path.dirname(path)):
            os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'wb') as fd:
            pickle.dump(data, fd)

    def _load_index(self):
        self.index = faiss.read_index(
            os.path.join(self.root_path, self.index_path))
        self.id_map = self._load_pickle(
            os.path.join(self.root_path, self.id_map_path))
        self.features = self._load_pickle(
            os.path.join(self.root_path, self.features_path))

    def _save_index(self, index, id_map, features):
        faiss.write_index(index, os.path.join(self.root_path, self.index_path))
        self._save_pickle(
            os.path.join(self.root_path, self.id_map_path), id_map)
        self._save_pickle(
            os.path.join(self.root_path, self.features_path), features)

    def _update_path(self, root_path, image_list_path=None):
        if root_path != self.root_path:
            self.root_path = root_path
            if not os.path.exists(os.path.join(root_path, "index")):
                os.mkdir(os.path.join(root_path, "index"))
            if image_list_path is not None:
                self.image_list_path = image_list_path

    def _cal_feature(self, image_list):
        batch_images = []
        features = None
        cnt = 0
        for idx, image_path in enumerate(image_list):
            image = cv2.imread(image_path)
            if image is None:
                return "{} is broken or does not exist. Stop".format(
                    image_path)
            else:
                # convert BGR (cv2 default) to RGB
                image = image[:, :, ::-1]
                batch_images.append(image)
                cnt += 1
                if cnt % self.config["Global"]["batch_size"] == 0 or (
                        idx + 1) == len(image_list):
                    if len(batch_images) == 0:
                        continue
                    batch_results = self.predictor.predict(batch_images)
                    features = batch_results if features is None else \
                        np.concatenate((features, batch_results), axis=0)
                    batch_images = []
        return features
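
    # NOTE: on success _cal_feature returns an (N, embedding_size) ndarray
    # with one embedding row per image, but on an unreadable image it returns
    # an error *string* instead; the callers below (create_index, _add_index)
    # pass the result straight to faiss and so assume every image is readable.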

    def _split_datafile(self, data_file, image_root):
        '''
        data_file: list file in which each line holds an image path and its
                   label, separated by whitespace
        image_root: root directory that the image paths are relative to
        '''
        gallery_images = []
        gallery_docs = []
        gallery_ids = []
        with open(data_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
            for ori_line in lines:
                line = ori_line.strip().split()
                text_num = len(line)
                assert text_num >= 2, f"line({ori_line}) must be split into at least 2 parts, but got {text_num}"
                image_file = os.path.join(image_root, line[0])

                gallery_images.append(image_file)
                gallery_docs.append(ori_line.strip())
                gallery_ids.append(os.path.basename(line[0]).split(".")[0])

        return gallery_images, gallery_docs, gallery_ids
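
    # Example: the (hypothetical) line "images/0d3f2a1b.jpg cat" with
    # image_root="/data/gallery" yields
    #   gallery_images: "/data/gallery/images/0d3f2a1b.jpg"
    #   gallery_docs:   "images/0d3f2a1b.jpg cat"
    #   gallery_ids:    "0d3f2a1b"  (file name without extension, i.e. the md5)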

    def create_index(self,
                     image_list: str,
                     index_method: str="HNSW32",
                     image_root: str=None):
        if not os.path.exists(image_list):
            return "{} does not exist".format(image_list)
        if index_method.lower() not in ['hnsw32', 'ivf', 'flat']:
            return "The index method only supports: HNSW32, IVF, Flat"
        self._update_path(os.path.dirname(image_list), image_list)

        # get image paths
        image_root = image_root if image_root is not None else self.root_path
        gallery_images, gallery_docs, image_ids = self._split_datafile(
            image_list, image_root)

        # generate index; compare case-insensitively, matching the check above
        if index_method.lower() == "ivf":
            # size the number of IVF coarse lists from the gallery,
            # clamped to [2, 65536]
            index_method = "IVF" + str(
                min(max(int(len(gallery_images) // 32), 2), 65536)) + ",Flat"
        index = faiss.index_factory(
            self.config["IndexProcess"]["embedding_size"], index_method,
            faiss.METRIC_INNER_PRODUCT)
        self.index = faiss.IndexIDMap2(index)
        features = self._cal_feature(gallery_images)
        self.index.train(features)
        index_ids = np.arange(0, len(gallery_images)).astype(np.int64)
        self.index.add_with_ids(features, index_ids)

        self.id_map = dict()
        for i, d in zip(list(index_ids), gallery_docs):
            self.id_map[i] = d

        self.features = {
            "features": features,
            "index_method": index_method,
            "image_ids": image_ids,
            "index_ids": index_ids.tolist()
        }
        self._save_index(self.index, self.id_map, self.features)
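
    # Worked example for the IVF branch above: with 1000 gallery images the
    # factory string becomes "IVF31,Flat" (1000 // 32 = 31, already inside the
    # [2, 65536] clamp), i.e. an inverted-file index with 31 coarse lists over
    # flat vectors compared by inner product.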

    def open_index(self, root_path: str, image_list_path: str) -> str:
        self._update_path(root_path)
        _, _, image_ids = self._split_datafile(image_list_path, root_path)
        if os.path.exists(os.path.join(self.root_path, self.index_path)) and \
                os.path.exists(os.path.join(self.root_path, self.id_map_path)) and \
                os.path.exists(os.path.join(self.root_path, self.features_path)):
            self._load_index()
            if set(image_ids) == set(self.features['image_ids']):
                return ""
            else:
                return "The image list differs from the index. Please update the index"
        else:
            return "Files do not exist: features.pkl, vector.index, id_map.pkl"

    def update_index(self, image_list: str, image_root: str=None) -> str:
        if self.index and self.id_map and self.features:
            image_root = image_root if image_root is not None else self.root_path
            image_paths, image_docs, image_ids = self._split_datafile(
                image_list, image_root)

            # add images that are in the list but not yet in the index
            add_ids = list(
                set(image_ids).difference(set(self.features["image_ids"])))
            add_indexes = [i for i, x in enumerate(image_ids) if x in add_ids]
            add_image_paths = [image_paths[i] for i in add_indexes]
            add_image_docs = [image_docs[i] for i in add_indexes]
            add_image_ids = [image_ids[i] for i in add_indexes]
            self._add_index(add_image_paths, add_image_docs, add_image_ids)

            # delete images that are in the index but no longer in the list
            delete_ids = list(
                set(self.features["image_ids"]).difference(set(image_ids)))
            self._delete_index(delete_ids)
            self._save_index(self.index, self.id_map, self.features)
            return ""
        else:
            return "Failed. Please create or open an index first"

    def _add_index(self, image_list: List, image_docs: List, image_ids: List):
        if len(image_ids) == 0:
            return
        features = self._cal_feature(image_list)
        # continue numbering after the largest id already in the map
        index_ids = (
            np.arange(0, len(image_list)) + max(self.id_map.keys()) + 1
        ).astype(np.int64)
        self.index.add_with_ids(features, index_ids)

        for i, d in zip(index_ids, image_docs):
            self.id_map[i] = d

        self.features['features'] = np.concatenate(
            [self.features['features'], features], axis=0)
        self.features['image_ids'].extend(image_ids)
        self.features['index_ids'].extend(index_ids.tolist())

    def _delete_index(self, image_ids: List):
        if len(image_ids) == 0:
            return
        indexes = [
            i for i, x in enumerate(self.features['image_ids'])
            if x in image_ids
        ]
        self.features["features"] = np.delete(
            self.features["features"], indexes, axis=0)
        self.features["image_ids"] = np.delete(
            np.asarray(self.features["image_ids"]), indexes, axis=0).tolist()
        index_ids = np.delete(
            np.asarray(self.features["index_ids"]), indexes, axis=0).tolist()
        id_map_values = [self.id_map[i] for i in index_ids]
        self.index.reset()
        ids = np.arange(0, len(id_map_values)).astype(np.int64)
        self.index.add_with_ids(self.features['features'], ids)
        self.id_map.clear()
        for i, d in zip(ids, id_map_values):
            self.id_map[i] = d
        # keep index_ids a plain list, matching create_index and _add_index
        # (storing the ndarray here would break the .extend() in _add_index)
        self.features["index_ids"] = ids.tolist()


app = FastAPI()


@app.get("/new_index")
|
||
def new_index(image_list_path: str,
|
||
index_method: str="HNSW32",
|
||
index_root_path: str=None,
|
||
force: bool=False):
|
||
result = ""
|
||
try:
|
||
if index_root_path is not None:
|
||
image_list_path = os.path.join(index_root_path, image_list_path)
|
||
index_path = os.path.join(index_root_path, "index", "vector.index")
|
||
id_map_path = os.path.join(index_root_path, "index", "id_map.pkl")
|
||
|
||
if not (os.path.exists(index_path) and
|
||
os.path.exists(id_map_path)) or force:
|
||
manager.create_index(image_list_path, index_method,
|
||
index_root_path)
|
||
else:
|
||
result = "There alrealy has index in {}".format(index_root_path)
|
||
except Exception as e:
|
||
result = e.__str__()
|
||
data = {"error_message": result}
|
||
return json.dumps(data).encode()


@app.get("/open_index")
def open_index(index_root_path: str, image_list_path: str):
    result = ""
    try:
        image_list_path = os.path.join(index_root_path, image_list_path)
        result = manager.open_index(index_root_path, image_list_path)
    except Exception as e:
        result = e.__str__()

    data = {"error_message": result}
    return json.dumps(data).encode()


@app.get("/update_index")
def update_index(image_list_path: str, index_root_path: str=None):
    result = ""
    try:
        if index_root_path is not None:
            image_list_path = os.path.join(index_root_path, image_list_path)
        result = manager.update_index(
            image_list=image_list_path, image_root=index_root_path)
    except Exception as e:
        result = e.__str__()
    data = {"error_message": result}
    return json.dumps(data).encode()
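

# Each endpoint responds with JSON bytes of the form {"error_message": "..."};
# an empty error_message means success. A minimal client sketch (hypothetical,
# assumes the third-party `requests` package and the 127.0.0.1:8000 fallback
# address used in __main__ below):
#
#     import requests
#     body = requests.get("http://127.0.0.1:8000/open_index",
#                         params={"index_root_path": "/data/gallery",
#                                 "image_list_path": "image_list.txt"}).content
#     # body contains '{"error_message": ""}' on success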


def FrontInterface(server_process=None):
    front = QtWidgets.QApplication([])
    main_window = mod.mainwindow.MainWindow(process=server_process)
    main_window.showMaximized()
    sys.exit(front.exec_())


def Server(app, host, port):
    uvicorn.run(app, host=host, port=port)


if __name__ == '__main__':
    args = config.parse_args()
    model_config = config.get_config(
        args.config, overrides=args.override, show=True)
    manager = ShiTuIndexManager(model_config)
    ip = model_config.get('ip', None)
    port = model_config.get('port', None)
    if ip is None or port is None:
        # fall back to the local host address (or loopback) and port 8000
        # when the config does not provide ip/port
        try:
            ip = socket.gethostbyname(socket.gethostname())
        except socket.error:
            ip = '127.0.0.1'
        port = 8000
    Server(app, ip, port)