fix shitu_index manager bug
parent
a33706ffb9
commit
e62054f4a3
deploy/shitu_index_manager
docs/zh_CN/inference_deployment
|
@ -0,0 +1,45 @@
|
|||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
from PyQt5 import QtCore, QtGui, QtWidgets
|
||||
import mod.mainwindow
|
||||
"""
|
||||
完整的index库如下:
|
||||
root_path/ # 库存储目录
|
||||
|-- image_list.txt # 图像列表,每行:image_path label。由前端生成及修改。后端只读
|
||||
|-- features.pkl # 建库之后,保存的embedding向量,后端生成,前端无需操作
|
||||
|-- images # 图像存储目录,由前端生成及增删查等操作。后端只读
|
||||
| |-- md5.jpg
|
||||
| |-- md5.jpg
|
||||
| |-- ……
|
||||
|-- index # 真正的生成的index库存储目录,后端生成及操作,前端无需操作。
|
||||
| |-- vector.index # faiss生成的索引库
|
||||
| |-- id_map.pkl # 索引文件
|
||||
"""
|
||||
|
||||
|
||||
def FrontInterface(server_ip=None, server_port=None):
    """Launch the Qt client UI and block until it exits.

    The given server ip/port are forwarded to the main window; the process
    terminates with the Qt event loop's exit code.
    """
    application = QtWidgets.QApplication([])
    window = mod.mainwindow.MainWindow(ip=server_ip, port=server_port)
    window.showMaximized()
    sys.exit(application.exec_())
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Optional single CLI argument of the form "<ip> <port>"; anything else
    # falls back to the front end's built-in defaults.
    server_ip, server_port = None, None
    if len(sys.argv) == 2:
        fields = sys.argv[1].split(' ')
        if len(fields) == 2:
            server_ip, server_port = fields
    FrontInterface(server_ip, server_port)
|
|
@ -13,22 +13,10 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
from PyQt5 import QtCore, QtGui, QtWidgets
|
||||
import mod.mainwindow
|
||||
|
||||
from paddleclas.deploy.utils import config, logger
|
||||
from paddleclas.deploy.python.predict_rec import RecPredictor
|
||||
from fastapi import FastAPI
|
||||
import uvicorn
|
||||
import numpy as np
|
||||
import faiss
|
||||
from typing import List
|
||||
import pickle
|
||||
import cv2
|
||||
import socket
|
||||
import json
|
||||
import operator
|
||||
from multiprocessing import Process
|
||||
import subprocess
|
||||
import shlex
|
||||
import psutil
|
||||
import time
|
||||
"""
|
||||
完整的index库如下:
|
||||
root_path/ # 库存储目录
|
||||
|
@ -43,307 +31,34 @@ root_path/ # 库存储目录
|
|||
| |-- id_map.pkl # 索引文件
|
||||
"""
|
||||
|
||||
|
||||
class ShiTuIndexManager(object):
    """Manage a PP-ShiTu gallery index stored under ``root_path``.

    On-disk layout (see module docstring): ``image_list.txt`` and
    ``images/`` are owned by the front end (read-only here);
    ``features.pkl`` and the ``index/`` directory (``vector.index`` +
    ``id_map.pkl``) are created and maintained by this manager.
    """

    def __init__(self, config):
        # All *_path attributes are relative to self.root_path, which is
        # set lazily by _update_path() when an index is created or opened.
        self.root_path = None
        self.image_list_path = "image_list.txt"
        self.image_dir = "images"
        self.index_path = "index/vector.index"
        self.id_map_path = "index/id_map.pkl"
        self.features_path = "features.pkl"
        self.index = None
        self.id_map = None
        self.features = None
        self.config = config
        self.predictor = RecPredictor(config)

    def _load_pickle(self, path):
        """Return the unpickled object stored at ``path``, or None if absent."""
        if os.path.exists(path):
            # Use a context manager so the file handle is not leaked.
            with open(path, 'rb') as fd:
                return pickle.load(fd)
        return None

    def _save_pickle(self, path, data):
        """Pickle ``data`` to ``path``, creating parent directories as needed."""
        parent = os.path.dirname(path)
        # Guard against a bare filename: os.makedirs("") raises.
        if parent and not os.path.exists(parent):
            os.makedirs(parent, exist_ok=True)
        with open(path, 'wb') as fd:
            pickle.dump(data, fd)

    def _load_index(self):
        """Load the faiss index, id_map and features from disk into memory."""
        self.index = faiss.read_index(
            os.path.join(self.root_path, self.index_path))
        self.id_map = self._load_pickle(
            os.path.join(self.root_path, self.id_map_path))
        self.features = self._load_pickle(
            os.path.join(self.root_path, self.features_path))

    def _save_index(self, index, id_map, features):
        """Persist the index triplet under self.root_path."""
        faiss.write_index(index,
                          os.path.join(self.root_path, self.index_path))
        self._save_pickle(
            os.path.join(self.root_path, self.id_map_path), id_map)
        self._save_pickle(
            os.path.join(self.root_path, self.features_path), features)

    def _update_path(self, root_path, image_list_path=None):
        """Point the manager at ``root_path``, creating root_path/index once.

        NOTE(review): image_list_path is only recorded when the root
        actually changes — preserved from the original behavior.
        """
        if root_path != self.root_path:
            self.root_path = root_path
            index_dir = os.path.join(root_path, "index")
            if not os.path.exists(index_dir):
                os.mkdir(index_dir)
            if image_list_path is not None:
                self.image_list_path = image_list_path

    def _cal_featrue(self, image_list):
        """Extract embeddings for every image path in ``image_list``.

        Images are fed to the predictor in batches of
        config["Global"]["batch_size"].  Returns the stacked feature
        matrix, or an error string if an image cannot be read
        (NOTE(review): callers currently do not check for the string case).
        """
        batch_images = []
        featrures = None
        cnt = 0
        for idx, image_path in enumerate(image_list):
            image = cv2.imread(image_path)
            if image is None:
                # BUG FIX: the placeholder was never filled in before.
                return "{} is broken or not exist. Stop".format(image_path)
            else:
                image = image[:, :, ::-1]  # BGR -> RGB for the predictor
                batch_images.append(image)
                cnt += 1
            # Flush a full batch, or the final partial batch.
            if cnt % self.config["Global"]["batch_size"] == 0 or (
                    idx + 1) == len(image_list):
                if len(batch_images) == 0:
                    continue
                batch_results = self.predictor.predict(batch_images)
                featrures = batch_results if featrures is None else \
                    np.concatenate((featrures, batch_results), axis=0)
                batch_images = []
        return featrures

    def _split_datafile(self, data_file, image_root):
        """Parse an image_list file.

        Each line is "<relative_image_path> <label...>" (whitespace
        separated).  Returns three parallel lists: absolute image paths,
        the stripped original lines, and ids derived from each file's
        basename without extension.
        """
        gallery_images = []
        gallery_docs = []
        gallery_ids = []
        with open(data_file, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        for ori_line in lines:
            line = ori_line.strip().split()
            text_num = len(line)
            assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
            image_file = os.path.join(image_root, line[0])

            gallery_images.append(image_file)
            gallery_docs.append(ori_line.strip())
            gallery_ids.append(os.path.basename(line[0]).split(".")[0])

        return gallery_images, gallery_docs, gallery_ids

    def create_index(self,
                     image_list: str,
                     index_method: str = "HNSW32",
                     image_root: str = None):
        """Build a brand-new index from ``image_list`` and persist it.

        index_method: HNSW32 / IVF / Flat (validated case-insensitively;
        faiss receives the caller's spelling verbatim).  Returns an error
        string on failure, None on success.
        """
        if not os.path.exists(image_list):
            return "{} is not exist".format(image_list)
        if index_method.lower() not in ['hnsw32', 'ivf', 'flat']:
            return "The index method Only support: HNSW32, IVF, Flat"
        self._update_path(os.path.dirname(image_list), image_list)

        # Gallery paths are resolved relative to image_root, defaulting to
        # the directory that contains image_list.
        image_root = image_root if image_root is not None else self.root_path
        gallery_images, gallery_docs, image_ids = self._split_datafile(
            image_list, image_root)

        # IVF needs an explicit nlist, derived from the gallery size and
        # clamped to [2, 65536].
        if index_method == "IVF":
            index_method = index_method + str(
                min(max(int(len(gallery_images) // 32), 2), 65536)) + ",Flat"
        index = faiss.index_factory(
            self.config["IndexProcess"]["embedding_size"], index_method,
            faiss.METRIC_INNER_PRODUCT)
        # IndexIDMap2 lets us address vectors by our own int64 ids.
        self.index = faiss.IndexIDMap2(index)
        features = self._cal_featrue(gallery_images)
        self.index.train(features)
        index_ids = np.arange(0, len(gallery_images)).astype(np.int64)
        self.index.add_with_ids(features, index_ids)

        self.id_map = dict()
        for i, d in zip(list(index_ids), gallery_docs):
            self.id_map[i] = d

        # "index_ids" is kept as a plain list; _add_index extends it.
        self.features = {
            "features": features,
            "index_method": index_method,
            "image_ids": image_ids,
            "index_ids": index_ids.tolist()
        }
        self._save_index(self.index, self.id_map, self.features)

    def open_index(self, root_path: str, image_list_path: str) -> str:
        """Open an existing index at ``root_path``.

        Returns "" on success, otherwise a human-readable error message.
        """
        self._update_path(root_path)
        _, _, image_ids = self._split_datafile(image_list_path, root_path)
        if os.path.exists(os.path.join(self.root_path, self.index_path)) and \
           os.path.exists(os.path.join(self.root_path, self.id_map_path)) and \
           os.path.exists(os.path.join(self.root_path, self.features_path)):
            self._load_index()
            # The on-disk index must cover exactly the listed images.
            if operator.eq(set(image_ids), set(self.features['image_ids'])):
                return ""
            else:
                return "The image list is different from index, Please update index"
        else:
            return "File not exist: features.pkl, vector.index, id_map.pkl"

    def update_index(self, image_list: str, image_root: str = None) -> str:
        """Synchronize the live index with ``image_list``.

        New images are embedded and added; images no longer listed are
        removed; the result is persisted.  Returns "" on success.
        """
        if self.index and self.id_map and self.features:
            image_paths, image_docs, image_ids = self._split_datafile(
                image_list,
                image_root if image_root is not None else self.root_path)

            # Newly listed images: present in the list, absent from the index.
            add_ids = list(
                set(image_ids).difference(set(self.features["image_ids"])))
            add_indexes = [i for i, x in enumerate(image_ids) if x in add_ids]
            add_image_paths = [image_paths[i] for i in add_indexes]
            add_image_docs = [image_docs[i] for i in add_indexes]
            add_image_ids = [image_ids[i] for i in add_indexes]
            self._add_index(add_image_paths, add_image_docs, add_image_ids)

            # Removed images: indexed but no longer listed.
            delete_ids = list(
                set(self.features["image_ids"]).difference(set(image_ids)))
            self._delete_index(delete_ids)
            self._save_index(self.index, self.id_map, self.features)
            return ""
        else:
            return "Failed. Please create or open index first"

    def _add_index(self, image_list: List, image_docs: List, image_ids: List):
        """Embed ``image_list`` and append the vectors to the live index."""
        if len(image_ids) == 0:
            return
        featrures = self._cal_featrue(image_list)
        # New faiss ids continue after the current maximum id.
        index_ids = (np.arange(0, len(image_list)) + max(self.id_map.keys()) +
                     1).astype(np.int64)
        self.index.add_with_ids(featrures, index_ids)

        for i, d in zip(index_ids, image_docs):
            self.id_map[i] = d

        self.features['features'] = np.concatenate(
            [self.features['features'], featrures], axis=0)
        self.features['image_ids'].extend(image_ids)
        self.features['index_ids'].extend(index_ids.tolist())

    def _delete_index(self, image_ids: List):
        """Drop ``image_ids`` from the features and rebuild the faiss index
        from the remaining vectors (faiss ids are re-numbered from zero)."""
        if len(image_ids) == 0:
            return
        indexes = [
            i for i, x in enumerate(self.features['image_ids'])
            if x in image_ids
        ]
        self.features["features"] = np.delete(self.features["features"],
                                              indexes,
                                              axis=0)
        self.features["image_ids"] = np.delete(np.asarray(
            self.features["image_ids"]),
                                               indexes,
                                               axis=0).tolist()
        index_ids = np.delete(np.asarray(self.features["index_ids"]),
                              indexes,
                              axis=0).tolist()
        id_map_values = [self.id_map[i] for i in index_ids]
        self.index.reset()
        ids = np.arange(0, len(id_map_values)).astype(np.int64)
        self.index.add_with_ids(self.features['features'], ids)
        self.id_map.clear()
        for i, d in zip(ids, id_map_values):
            self.id_map[i] = d
        # BUG FIX: keep "index_ids" a plain list, consistent with
        # create_index/_add_index (which calls .extend on it).  Previously an
        # ndarray was stored here, so the next _add_index call crashed with
        # AttributeError.
        self.features["index_ids"] = ids.tolist()
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.get("/new_index")
|
||||
def new_index(image_list_path: str,
|
||||
index_method: str = "HNSW32",
|
||||
index_root_path: str = None,
|
||||
force: bool = False):
|
||||
result = ""
|
||||
try:
|
||||
if index_root_path is not None:
|
||||
image_list_path = os.path.join(index_root_path, image_list_path)
|
||||
index_path = os.path.join(index_root_path, "index", "vector.index")
|
||||
id_map_path = os.path.join(index_root_path, "index", "id_map.pkl")
|
||||
|
||||
if not (os.path.exists(index_path)
|
||||
and os.path.exists(id_map_path)) or force:
|
||||
manager.create_index(image_list_path, index_method, index_root_path)
|
||||
else:
|
||||
result = "There alrealy has index in {}".format(index_root_path)
|
||||
except Exception as e:
|
||||
result = e.__str__()
|
||||
data = {"error_message": result}
|
||||
return json.dumps(data).encode()
|
||||
|
||||
|
||||
@app.get("/open_index")
|
||||
def open_index(index_root_path: str, image_list_path: str):
|
||||
result = ""
|
||||
try:
|
||||
image_list_path = os.path.join(index_root_path, image_list_path)
|
||||
result = manager.open_index(index_root_path, image_list_path)
|
||||
except Exception as e:
|
||||
result = e.__str__()
|
||||
|
||||
data = {"error_message": result}
|
||||
return json.dumps(data).encode()
|
||||
|
||||
|
||||
@app.get("/update_index")
|
||||
def update_index(image_list_path: str, index_root_path: str = None):
|
||||
result = ""
|
||||
try:
|
||||
if index_root_path is not None:
|
||||
image_list_path = os.path.join(index_root_path, image_list_path)
|
||||
result = manager.update_index(image_list=image_list_path,
|
||||
image_root=index_root_path)
|
||||
except Exception as e:
|
||||
result = e.__str__()
|
||||
data = {"error_message": result}
|
||||
return json.dumps(data).encode()
|
||||
|
||||
|
||||
def FrontInterface(server_process=None):
    """Run the Qt client UI, handing it the backend server process handle.

    Blocks in the Qt event loop and exits the process with its return code.
    """
    qt_app = QtWidgets.QApplication([])
    window = mod.mainwindow.MainWindow(process=server_process)
    window.showMaximized()
    sys.exit(qt_app.exec_())
|
||||
|
||||
|
||||
def Server(args):
    """Process entry point for the backend: args is [app, host, port].

    Blocks inside uvicorn until the process is terminated.
    """
    fastapi_app, host, port = args
    uvicorn.run(fastapi_app, host=host, port=port)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = config.parse_args()
|
||||
model_config = config.get_config(args.config,
|
||||
overrides=args.override,
|
||||
show=True)
|
||||
manager = ShiTuIndexManager(model_config)
|
||||
if not (len(sys.argv) == 3 or len(sys.argv) == 5):
|
||||
print("start example:")
|
||||
print(" python index_manager.py -c xxx.yaml")
|
||||
print(" python index_manager.py -c xxx.yaml -p port")
|
||||
yaml_path = sys.argv[2]
|
||||
if len(sys.argv) == 5:
|
||||
port = sys.argv[4]
|
||||
else:
|
||||
port = 8000
|
||||
assert int(port) > 1024 and int(
|
||||
port) < 65536, "The port should be bigger than 1024 and \
|
||||
smaller than 65536"
|
||||
|
||||
try:
|
||||
ip = socket.gethostbyname(socket.gethostname())
|
||||
except:
|
||||
ip = '127.0.0.1'
|
||||
port = 8000
|
||||
p_server = Process(target=Server, args=([app, ip, port],))
|
||||
p_server.start()
|
||||
# p_client = Process(target=FrontInterface, args=())
|
||||
# p_client.start()
|
||||
# p_client.join()
|
||||
FrontInterface(p_server)
|
||||
p_server.terminate()
|
||||
sys.exit(0)
|
||||
server_cmd = "python server.py -c {} -o ip={} -o port={}".format(yaml_path,
|
||||
ip, port)
|
||||
server_proc = subprocess.Popen(shlex.split(server_cmd))
|
||||
client_proc = subprocess.Popen(
|
||||
["python", "client.py", "{} {}".format(ip, port)])
|
||||
try:
|
||||
while psutil.Process(client_proc.pid).status() == "running":
|
||||
time.sleep(0.5)
|
||||
except:
|
||||
pass
|
||||
|
||||
client_proc.terminate()
|
||||
server_proc.terminate()
|
||||
|
|
|
@ -22,8 +22,6 @@ try:
|
|||
DEFAULT_HOST = socket.gethostbyname(socket.gethostname())
|
||||
except:
|
||||
DEFAULT_HOST = '127.0.0.1'
|
||||
|
||||
# DEFAULT_HOST = "localhost"
|
||||
DEFAULT_PORT = 8000
|
||||
PADDLECLAS_DOC_URL = "https://gitee.com/paddlepaddle/PaddleClas/docs/zh_CN/inference_deployment/shitu_gallery_manager.md"
|
||||
|
||||
|
@ -35,12 +33,17 @@ class MainWindow(QtWidgets.QMainWindow):
|
|||
updateIndexMsg = QtCore.pyqtSignal(str) # 更新索引库线程信号
|
||||
importImageCount = QtCore.pyqtSignal(int) # 导入图像数量信号
|
||||
|
||||
def __init__(self, process=None):
|
||||
def __init__(self, ip=None, port=None):
|
||||
super(MainWindow, self).__init__()
|
||||
self.server_process = process
|
||||
if ip is not None and port is not None:
|
||||
self.server_ip = ip
|
||||
self.server_port = port
|
||||
else:
|
||||
self.server_ip = DEFAULT_HOST
|
||||
self.server_port = DEFAULT_PORT
|
||||
|
||||
self.ui = ui_mainwindow.Ui_MainWindow()
|
||||
self.ui.setupUi(self) # 初始化主窗口界面
|
||||
|
||||
self.__imageListMgr = image_list_manager.ImageListManager()
|
||||
|
||||
self.__appMenu = QtWidgets.QMenu() # 应用菜单
|
||||
|
@ -115,8 +118,7 @@ class MainWindow(QtWidgets.QMainWindow):
|
|||
self.ui.saveImageLibraryBtn.clicked.connect(self.saveImageLibrary)
|
||||
|
||||
self.__setToolButton(self.ui.addClassifyBtn, "添加分类",
|
||||
"./resource/add_classify.png",
|
||||
TOOL_BTN_ICON_SIZE)
|
||||
"./resource/add_classify.png", TOOL_BTN_ICON_SIZE)
|
||||
self.ui.addClassifyBtn.clicked.connect(
|
||||
self.__classifyUiContext.addClassify)
|
||||
|
||||
|
@ -145,7 +147,10 @@ class MainWindow(QtWidgets.QMainWindow):
|
|||
self.ui.searchClassifyHistoryCmb.setToolTip("查找分类历史")
|
||||
self.ui.imageScaleSlider.setToolTip("图片缩放")
|
||||
|
||||
def __setToolButton(self, button, tool_tip: str, icon_path: str,
|
||||
def __setToolButton(self,
|
||||
button,
|
||||
tool_tip: str,
|
||||
icon_path: str,
|
||||
icon_size: int):
|
||||
"""设置工具按钮"""
|
||||
button.setToolTip(tool_tip)
|
||||
|
@ -160,9 +165,9 @@ class MainWindow(QtWidgets.QMainWindow):
|
|||
|
||||
self.__libraryAppendMenu.setTitle("导入图像")
|
||||
utils.setMenu(self.__libraryAppendMenu, "导入 image_list 图像",
|
||||
self.importImageListImage)
|
||||
self.importImageListImage)
|
||||
utils.setMenu(self.__libraryAppendMenu, "导入多文件夹图像",
|
||||
self.importDirsImage)
|
||||
self.importDirsImage)
|
||||
self.__appMenu.addMenu(self.__libraryAppendMenu)
|
||||
|
||||
self.__appMenu.addSeparator()
|
||||
|
@ -179,16 +184,16 @@ class MainWindow(QtWidgets.QMainWindow):
|
|||
def __initWaitDialog(self):
|
||||
"""初始化等待对话框"""
|
||||
self.__waitDialogUi.setupUi(self.__waitDialog)
|
||||
self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog
|
||||
| QtCore.Qt.FramelessWindowHint)
|
||||
self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog |
|
||||
QtCore.Qt.FramelessWindowHint)
|
||||
|
||||
def __startWait(self, msg: str):
|
||||
"""开始显示等待对话框"""
|
||||
self.setEnabled(False)
|
||||
self.__waitDialogUi.msgLabel.setText(msg)
|
||||
self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog
|
||||
| QtCore.Qt.FramelessWindowHint
|
||||
| QtCore.Qt.WindowStaysOnTopHint)
|
||||
self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog |
|
||||
QtCore.Qt.FramelessWindowHint |
|
||||
QtCore.Qt.WindowStaysOnTopHint)
|
||||
self.__waitDialog.show()
|
||||
self.__waitDialog.repaint()
|
||||
|
||||
|
@ -196,9 +201,9 @@ class MainWindow(QtWidgets.QMainWindow):
|
|||
"""停止显示等待对话框"""
|
||||
self.setEnabled(True)
|
||||
self.__waitDialogUi.msgLabel.setText("执行完毕!")
|
||||
self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog
|
||||
| QtCore.Qt.FramelessWindowHint
|
||||
| QtCore.Qt.CustomizeWindowHint)
|
||||
self.__waitDialog.setWindowFlags(QtCore.Qt.Dialog |
|
||||
QtCore.Qt.FramelessWindowHint |
|
||||
QtCore.Qt.CustomizeWindowHint)
|
||||
self.__waitDialog.close()
|
||||
|
||||
def __connectSignal(self):
|
||||
|
@ -290,8 +295,8 @@ class MainWindow(QtWidgets.QMainWindow):
|
|||
|
||||
def __importImageListImageThread(self, from_path: str, to_path: str):
|
||||
"""导入 image_list 图像 线程"""
|
||||
count = utils.oneKeyImportFromFile(from_path=from_path,
|
||||
to_path=to_path)
|
||||
count = utils.oneKeyImportFromFile(
|
||||
from_path=from_path, to_path=to_path)
|
||||
if count == None:
|
||||
count = -1
|
||||
self.importImageCount.emit(count)
|
||||
|
@ -308,9 +313,9 @@ class MainWindow(QtWidgets.QMainWindow):
|
|||
return
|
||||
from_mgr = image_list_manager.ImageListManager(from_path)
|
||||
self.__startWait("正在导入图像,请等待。。。")
|
||||
thread = threading.Thread(target=self.__importImageListImageThread,
|
||||
args=(from_mgr.filePath,
|
||||
self.__imageListMgr.filePath))
|
||||
thread = threading.Thread(
|
||||
target=self.__importImageListImageThread,
|
||||
args=(from_mgr.filePath, self.__imageListMgr.filePath))
|
||||
thread.start()
|
||||
|
||||
def __importDirsImageThread(self, from_dir: str, to_image_list_path: str):
|
||||
|
@ -333,21 +338,25 @@ class MainWindow(QtWidgets.QMainWindow):
|
|||
QtWidgets.QMessageBox.information(self, "提示", "打开的目录不存在")
|
||||
return
|
||||
self.__startWait("正在导入图像,请等待。。。")
|
||||
thread = threading.Thread(target=self.__importDirsImageThread,
|
||||
args=(dir_path,
|
||||
self.__imageListMgr.filePath))
|
||||
thread = threading.Thread(
|
||||
target=self.__importDirsImageThread,
|
||||
args=(dir_path, self.__imageListMgr.filePath))
|
||||
thread.start()
|
||||
|
||||
def __newIndexThread(self, index_root_path: str, image_list_path: str,
|
||||
index_method: str, force: bool):
|
||||
def __newIndexThread(self,
|
||||
index_root_path: str,
|
||||
image_list_path: str,
|
||||
index_method: str,
|
||||
force: bool):
|
||||
"""新建重建索引库线程"""
|
||||
try:
|
||||
client = index_http_client.IndexHttpClient(
|
||||
DEFAULT_HOST, DEFAULT_PORT)
|
||||
err_msg = client.new_index(image_list_path=image_list_path,
|
||||
index_root_path=index_root_path,
|
||||
index_method=index_method,
|
||||
force=force)
|
||||
client = index_http_client.IndexHttpClient(self.server_ip,
|
||||
self.server_port)
|
||||
err_msg = client.new_index(
|
||||
image_list_path=image_list_path,
|
||||
index_root_path=index_root_path,
|
||||
index_method=index_method,
|
||||
force=force)
|
||||
if err_msg == None:
|
||||
err_msg = ""
|
||||
self.newIndexMsg.emit(err_msg)
|
||||
|
@ -375,19 +384,20 @@ class MainWindow(QtWidgets.QMainWindow):
|
|||
force = ui.resetCheckBox.isChecked()
|
||||
if result == QtWidgets.QDialog.Accepted:
|
||||
self.__startWait("正在 新建/重建 索引库,请等待。。。")
|
||||
thread = threading.Thread(target=self.__newIndexThread,
|
||||
args=(self.__imageListMgr.dirName,
|
||||
"image_list.txt", index_method,
|
||||
force))
|
||||
thread = threading.Thread(
|
||||
target=self.__newIndexThread,
|
||||
args=(self.__imageListMgr.dirName, "image_list.txt",
|
||||
index_method, force))
|
||||
thread.start()
|
||||
|
||||
def __openIndexThread(self, index_root_path: str, image_list_path: str):
|
||||
"""打开索引库线程"""
|
||||
try:
|
||||
client = index_http_client.IndexHttpClient(
|
||||
DEFAULT_HOST, DEFAULT_PORT)
|
||||
err_msg = client.open_index(index_root_path=index_root_path,
|
||||
image_list_path=image_list_path)
|
||||
client = index_http_client.IndexHttpClient(self.server_ip,
|
||||
self.server_port)
|
||||
err_msg = client.open_index(
|
||||
index_root_path=index_root_path,
|
||||
image_list_path=image_list_path)
|
||||
if err_msg == None:
|
||||
err_msg = ""
|
||||
self.openIndexMsg.emit(err_msg)
|
||||
|
@ -408,18 +418,19 @@ class MainWindow(QtWidgets.QMainWindow):
|
|||
QtWidgets.QMessageBox.information(self, "提示", "请先打开正确的图像库")
|
||||
return
|
||||
self.__startWait("正在打开索引库,请等待。。。")
|
||||
thread = threading.Thread(target=self.__openIndexThread,
|
||||
args=(self.__imageListMgr.dirName,
|
||||
"image_list.txt"))
|
||||
thread = threading.Thread(
|
||||
target=self.__openIndexThread,
|
||||
args=(self.__imageListMgr.dirName, "image_list.txt"))
|
||||
thread.start()
|
||||
|
||||
def __updateIndexThread(self, index_root_path: str, image_list_path: str):
|
||||
"""更新索引库线程"""
|
||||
try:
|
||||
client = index_http_client.IndexHttpClient(
|
||||
DEFAULT_HOST, DEFAULT_PORT)
|
||||
err_msg = client.update_index(image_list_path=image_list_path,
|
||||
index_root_path=index_root_path)
|
||||
client = index_http_client.IndexHttpClient(self.server_ip,
|
||||
self.server_port)
|
||||
err_msg = client.update_index(
|
||||
image_list_path=image_list_path,
|
||||
index_root_path=index_root_path)
|
||||
if err_msg == None:
|
||||
err_msg = ""
|
||||
self.updateIndexMsg.emit(err_msg)
|
||||
|
@ -440,9 +451,9 @@ class MainWindow(QtWidgets.QMainWindow):
|
|||
QtWidgets.QMessageBox.information(self, "提示", "请先打开正确的图像库")
|
||||
return
|
||||
self.__startWait("正在更新索引库,请等待。。。")
|
||||
thread = threading.Thread(target=self.__updateIndexThread,
|
||||
args=(self.__imageListMgr.dirName,
|
||||
"image_list.txt"))
|
||||
thread = threading.Thread(
|
||||
target=self.__updateIndexThread,
|
||||
args=(self.__imageListMgr.dirName, "image_list.txt"))
|
||||
thread.start()
|
||||
|
||||
def searchClassify(self):
|
||||
|
@ -471,9 +482,6 @@ class MainWindow(QtWidgets.QMainWindow):
|
|||
|
||||
def exitApp(self):
|
||||
"""退出应用"""
|
||||
if isinstance(self.server_process, Process):
|
||||
self.server_process.terminate()
|
||||
# os.kill(self.server_pid)
|
||||
sys.exit(0)
|
||||
|
||||
def __setPathBar(self, msg: str):
|
||||
|
|
|
@ -0,0 +1,340 @@
|
|||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
from PyQt5 import QtCore, QtGui, QtWidgets
|
||||
import mod.mainwindow
|
||||
|
||||
from paddleclas.deploy.utils import config, logger
|
||||
from paddleclas.deploy.python.predict_rec import RecPredictor
|
||||
from fastapi import FastAPI
|
||||
import uvicorn
|
||||
import numpy as np
|
||||
import faiss
|
||||
from typing import List
|
||||
import pickle
|
||||
import cv2
|
||||
import socket
|
||||
import json
|
||||
import operator
|
||||
from multiprocessing import Process
|
||||
"""
|
||||
完整的index库如下:
|
||||
root_path/ # 库存储目录
|
||||
|-- image_list.txt # 图像列表,每行:image_path label。由前端生成及修改。后端只读
|
||||
|-- features.pkl # 建库之后,保存的embedding向量,后端生成,前端无需操作
|
||||
|-- images # 图像存储目录,由前端生成及增删查等操作。后端只读
|
||||
| |-- md5.jpg
|
||||
| |-- md5.jpg
|
||||
| |-- ……
|
||||
|-- index # 真正的生成的index库存储目录,后端生成及操作,前端无需操作。
|
||||
| |-- vector.index # faiss生成的索引库
|
||||
| |-- id_map.pkl # 索引文件
|
||||
"""
|
||||
|
||||
|
||||
class ShiTuIndexManager(object):
|
||||
def __init__(self, config):
|
||||
self.root_path = None
|
||||
self.image_list_path = "image_list.txt"
|
||||
self.image_dir = "images"
|
||||
self.index_path = "index/vector.index"
|
||||
self.id_map_path = "index/id_map.pkl"
|
||||
self.features_path = "features.pkl"
|
||||
self.index = None
|
||||
self.id_map = None
|
||||
self.features = None
|
||||
self.config = config
|
||||
self.predictor = RecPredictor(config)
|
||||
|
||||
def _load_pickle(self, path):
|
||||
if os.path.exists(path):
|
||||
return pickle.load(open(path, 'rb'))
|
||||
else:
|
||||
return None
|
||||
|
||||
def _save_pickle(self, path, data):
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||
with open(path, 'wb') as fd:
|
||||
pickle.dump(data, fd)
|
||||
|
||||
def _load_index(self):
|
||||
self.index = faiss.read_index(
|
||||
os.path.join(self.root_path, self.index_path))
|
||||
self.id_map = self._load_pickle(
|
||||
os.path.join(self.root_path, self.id_map_path))
|
||||
self.features = self._load_pickle(
|
||||
os.path.join(self.root_path, self.features_path))
|
||||
|
||||
def _save_index(self, index, id_map, features):
|
||||
faiss.write_index(index, os.path.join(self.root_path, self.index_path))
|
||||
self._save_pickle(
|
||||
os.path.join(self.root_path, self.id_map_path), id_map)
|
||||
self._save_pickle(
|
||||
os.path.join(self.root_path, self.features_path), features)
|
||||
|
||||
def _update_path(self, root_path, image_list_path=None):
|
||||
if root_path == self.root_path:
|
||||
pass
|
||||
else:
|
||||
self.root_path = root_path
|
||||
if not os.path.exists(os.path.join(root_path, "index")):
|
||||
os.mkdir(os.path.join(root_path, "index"))
|
||||
if image_list_path is not None:
|
||||
self.image_list_path = image_list_path
|
||||
|
||||
def _cal_featrue(self, image_list):
|
||||
batch_images = []
|
||||
featrures = None
|
||||
cnt = 0
|
||||
for idx, image_path in enumerate(image_list):
|
||||
image = cv2.imread(image_path)
|
||||
if image is None:
|
||||
return "{} is broken or not exist. Stop"
|
||||
else:
|
||||
image = image[:, :, ::-1]
|
||||
batch_images.append(image)
|
||||
cnt += 1
|
||||
if cnt % self.config["Global"]["batch_size"] == 0 or (
|
||||
idx + 1) == len(image_list):
|
||||
if len(batch_images) == 0:
|
||||
continue
|
||||
batch_results = self.predictor.predict(batch_images)
|
||||
featrures = batch_results if featrures is None else np.concatenate(
|
||||
(featrures, batch_results), axis=0)
|
||||
batch_images = []
|
||||
return featrures
|
||||
|
||||
def _split_datafile(self, data_file, image_root):
|
||||
'''
|
||||
data_file: image path and info, which can be splitted by spacer
|
||||
image_root: image path root
|
||||
delimiter: delimiter
|
||||
'''
|
||||
gallery_images = []
|
||||
gallery_docs = []
|
||||
gallery_ids = []
|
||||
with open(data_file, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
for _, ori_line in enumerate(lines):
|
||||
line = ori_line.strip().split()
|
||||
text_num = len(line)
|
||||
assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
|
||||
image_file = os.path.join(image_root, line[0])
|
||||
|
||||
gallery_images.append(image_file)
|
||||
gallery_docs.append(ori_line.strip())
|
||||
gallery_ids.append(os.path.basename(line[0]).split(".")[0])
|
||||
|
||||
return gallery_images, gallery_docs, gallery_ids
|
||||
|
||||
def create_index(self,
|
||||
image_list: str,
|
||||
index_method: str="HNSW32",
|
||||
image_root: str=None):
|
||||
if not os.path.exists(image_list):
|
||||
return "{} is not exist".format(image_list)
|
||||
if index_method.lower() not in ['hnsw32', 'ivf', 'flat']:
|
||||
return "The index method Only support: HNSW32, IVF, Flat"
|
||||
self._update_path(os.path.dirname(image_list), image_list)
|
||||
|
||||
# get image_paths
|
||||
image_root = image_root if image_root is not None else self.root_path
|
||||
gallery_images, gallery_docs, image_ids = self._split_datafile(
|
||||
image_list, image_root)
|
||||
|
||||
# gernerate index
|
||||
if index_method == "IVF":
|
||||
index_method = index_method + str(
|
||||
min(max(int(len(gallery_images) // 32), 2), 65536)) + ",Flat"
|
||||
index = faiss.index_factory(
|
||||
self.config["IndexProcess"]["embedding_size"], index_method,
|
||||
faiss.METRIC_INNER_PRODUCT)
|
||||
self.index = faiss.IndexIDMap2(index)
|
||||
features = self._cal_featrue(gallery_images)
|
||||
self.index.train(features)
|
||||
index_ids = np.arange(0, len(gallery_images)).astype(np.int64)
|
||||
self.index.add_with_ids(features, index_ids)
|
||||
|
||||
self.id_map = dict()
|
||||
for i, d in zip(list(index_ids), gallery_docs):
|
||||
self.id_map[i] = d
|
||||
|
||||
self.features = {
|
||||
"features": features,
|
||||
"index_method": index_method,
|
||||
"image_ids": image_ids,
|
||||
"index_ids": index_ids.tolist()
|
||||
}
|
||||
self._save_index(self.index, self.id_map, self.features)
|
||||
|
||||
def open_index(self, root_path: str, image_list_path: str) -> str:
|
||||
self._update_path(root_path)
|
||||
_, _, image_ids = self._split_datafile(image_list_path, root_path)
|
||||
if os.path.exists(os.path.join(self.root_path, self.index_path)) and \
|
||||
os.path.exists(os.path.join(self.root_path, self.id_map_path)) and \
|
||||
os.path.exists(os.path.join(self.root_path, self.features_path)):
|
||||
self._update_path(root_path)
|
||||
self._load_index()
|
||||
if operator.eq(set(image_ids), set(self.features['image_ids'])):
|
||||
return ""
|
||||
else:
|
||||
return "The image list is different from index, Please update index"
|
||||
else:
|
||||
return "File not exist: features.pkl, vector.index, id_map.pkl"
|
||||
|
||||
def update_index(self, image_list: str, image_root: str=None) -> str:
    """Synchronize the index with the current image list.

    Images present in ``image_list`` but not in the index are added;
    images that disappeared from the list are removed; the result is
    persisted to disk.

    Returns:
        "" on success, otherwise an error message.
    """
    if not (self.index and self.id_map and self.features):
        return "Failed. Please create or open index first"

    root = self.root_path if image_root is None else image_root
    paths, docs, ids = self._split_datafile(image_list, root)

    # Images listed on disk but not yet indexed -> add them.
    known_ids = set(self.features["image_ids"])
    new_ids = list(set(ids).difference(known_ids))
    new_pos = [pos for pos, img_id in enumerate(ids) if img_id in new_ids]
    self._add_index([paths[pos] for pos in new_pos],
                    [docs[pos] for pos in new_pos],
                    [ids[pos] for pos in new_pos])

    # Images indexed earlier but missing from the list -> remove them.
    stale_ids = list(known_ids.difference(set(ids)))
    self._delete_index(stale_ids)

    self._save_index(self.index, self.id_map, self.features)
    return ""
|
||||
|
||||
def _add_index(self, image_list: List, image_docs: List, image_ids: List):
    """Append new images to the faiss index and the bookkeeping state.

    Args:
        image_list: paths of the images to embed and add.
        image_docs: "image_path label" doc string for each image.
        image_ids: content ids of the images (parallel to image_list).
    """
    if len(image_ids) == 0:
        return
    features = self._cal_featrue(image_list)
    # Continue numbering after the largest existing index id; ``default``
    # keeps this safe even when the id_map is still empty (the original
    # max() would raise ValueError in that case).
    next_id = max(self.id_map.keys(), default=-1) + 1
    index_ids = (np.arange(0, len(image_list)) + next_id).astype(np.int64)
    self.index.add_with_ids(features, index_ids)

    for idx, doc in zip(index_ids, image_docs):
        self.id_map[idx] = doc

    self.features['features'] = np.concatenate(
        [self.features['features'], features], axis=0)
    self.features['image_ids'].extend(image_ids)
    self.features['index_ids'].extend(index_ids.tolist())
|
||||
|
||||
def _delete_index(self, image_ids: List):
    """Remove images from the index, rebuilding it without them.

    The remaining vectors are re-added with freshly compacted,
    zero-based ids and the id_map is rebuilt to match.

    Args:
        image_ids: content ids of the images to drop.
    """
    if len(image_ids) == 0:
        return
    removed = [
        pos for pos, img_id in enumerate(self.features['image_ids'])
        if img_id in image_ids
    ]
    self.features["features"] = np.delete(
        self.features["features"], removed, axis=0)
    self.features["image_ids"] = np.delete(
        np.asarray(self.features["image_ids"]), removed, axis=0).tolist()
    kept_index_ids = np.delete(
        np.asarray(self.features["index_ids"]), removed, axis=0).tolist()
    kept_docs = [self.id_map[i] for i in kept_index_ids]

    # Rebuild the faiss index with compacted ids.
    self.index.reset()
    new_ids = np.arange(0, len(kept_docs)).astype(np.int64)
    self.index.add_with_ids(self.features['features'], new_ids)
    self.id_map.clear()
    for idx, doc in zip(new_ids, kept_docs):
        self.id_map[idx] = doc
    # Store a plain list: _add_index later calls .extend() on this field,
    # which fails on the ndarray the original assigned here.
    self.features["index_ids"] = new_ids.tolist()
|
||||
|
||||
|
||||
# FastAPI application exposing the index-manager operations as HTTP endpoints.
app = FastAPI()
|
||||
|
||||
|
||||
@app.get("/new_index")
def new_index(image_list_path: str,
              index_method: str="HNSW32",
              index_root_path: str=None,
              force: bool=False):
    """HTTP endpoint: build a brand-new index library.

    Args:
        image_list_path: image list file; treated as relative to
            ``index_root_path`` when the latter is given.
        index_method: faiss index type (HNSW32, IVF or Flat).
        index_root_path: root directory of the index library.
        force: rebuild even if an index already exists on disk.

    Returns:
        JSON-encoded bytes with an ``error_message`` field ("" on success).
    """
    result = ""
    try:
        if index_root_path is not None:
            image_list_path = os.path.join(index_root_path, image_list_path)
        # NOTE(review): a None index_root_path makes these joins raise and
        # the error is returned to the client — presumably intended as the
        # "root path is required" behavior; confirm with the front end.
        index_path = os.path.join(index_root_path, "index", "vector.index")
        id_map_path = os.path.join(index_root_path, "index", "id_map.pkl")

        # Only build when no index exists yet, unless the caller forces it.
        if not (os.path.exists(index_path) and
                os.path.exists(id_map_path)) or force:
            manager.create_index(image_list_path, index_method,
                                 index_root_path)
        else:
            # Fixed garbled message (was "There alrealy has index in {}").
            result = "There already is an index in {}".format(index_root_path)
    except Exception as e:
        result = str(e)
    data = {"error_message": result}
    return json.dumps(data).encode()
|
||||
|
||||
|
||||
@app.get("/open_index")
def open_index(index_root_path: str, image_list_path: str):
    """HTTP endpoint: open an existing index library.

    Returns:
        JSON-encoded bytes with an ``error_message`` field ("" on success).
    """
    try:
        full_list_path = os.path.join(index_root_path, image_list_path)
        message = manager.open_index(index_root_path, full_list_path)
    except Exception as exc:
        message = str(exc)
    return json.dumps({"error_message": message}).encode()
|
||||
|
||||
|
||||
@app.get("/update_index")
def update_index(image_list_path: str, index_root_path: str=None):
    """HTTP endpoint: bring the opened index in sync with the image list.

    Returns:
        JSON-encoded bytes with an ``error_message`` field ("" on success).
    """
    try:
        list_path = (image_list_path if index_root_path is None
                     else os.path.join(index_root_path, image_list_path))
        message = manager.update_index(
            image_list=list_path, image_root=index_root_path)
    except Exception as exc:
        message = str(exc)
    return json.dumps({"error_message": message}).encode()
|
||||
|
||||
|
||||
def FrontInterface(server_process=None):
    """Launch the Qt front-end and block until the GUI exits.

    Args:
        server_process: backend server process handle handed to the
            main window (may be None).
    """
    gui_app = QtWidgets.QApplication([])
    window = mod.mainwindow.MainWindow(process=server_process)
    window.showMaximized()
    sys.exit(gui_app.exec_())
|
||||
|
||||
|
||||
def Server(app, host, port):
    """Serve the FastAPI app with uvicorn; blocks until the server stops."""
    uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    args = config.parse_args()
    model_config = config.get_config(
        args.config, overrides=args.override, show=True)
    manager = ShiTuIndexManager(model_config)

    # Fall back to sensible defaults for any address field the config does
    # not provide. Each field is handled independently: the original code
    # overwrote a configured ip whenever only the port was missing.
    ip = model_config.get('ip', None)
    port = model_config.get('port', None)
    if ip is None:
        try:
            ip = socket.gethostbyname(socket.gethostname())
        except OSError:  # hostname not resolvable -> use loopback
            ip = '127.0.0.1'
    if port is None:
        port = 8000
    Server(app, ip, port)
|
|
@ -22,7 +22,7 @@
|
|||
- [2. 使用说明](#2)
|
||||
|
||||
- [2.1 环境安装](#2.1)
|
||||
- [2.2 模型准备](#2.2)
|
||||
- [2.2 模型及数据准备](#2.2)
|
||||
- [2.3运行使用](#2.3)
|
||||
|
||||
- [3.生成文件介绍](#3)
|
||||
|
@ -90,7 +90,7 @@
|
|||
在打开图像库或者新建图像库完成后,可以使用导入图像功能,即导入用户自己生成好的图像库。具体有支持两种导入格式
|
||||
|
||||
- image_list格式:打开具体的`.txt`文件。`.txt`文件中每一行格式: `image_path label`。根据文件路径及label导入
|
||||
- 多文件夹格式:打开`具体文件夹`,此文件夹下存储多个子文件夹,每个子文件夹名字为`label_name`,每个子文件夹中存储对应的图像数据。
|
||||
- 多文件夹格式:打开`具体文件夹`,此文件夹下存储多个子文件夹,每个子文件夹名字为`label_name`,每个子文件夹中存储对应的图像数据。
|
||||
|
||||
<a name="1.4"></a>
|
||||
|
||||
|
@ -123,13 +123,25 @@
|
|||
pip install fastapi
|
||||
pip install uvicorn
|
||||
pip install pyqt5
|
||||
pip install psutil
|
||||
```
|
||||
|
||||
<a name="2.2"></a>
|
||||
|
||||
### 2.2 模型准备
|
||||
### 2.2 模型及数据准备
|
||||
|
||||
请按照[PP-ShiTu快速体验](../quick_start/quick_start_recognition.md#2.2.1)中下载及准备inference model,并修改好`${PaddleClas}/deploy/configs/inference_drink.yaml`的相关参数。
|
||||
请按照[PP-ShiTu快速体验](../quick_start/quick_start_recognition.md#2.2.1)中下载及准备inference model,并修改好`${PaddleClas}/deploy/configs/inference_drink.yaml`的相关参数,同时准备好数据集。在具体使用时,请替换好自己的数据集及模型文件。
|
||||
|
||||
```shell
|
||||
cd ${PaddleClas}/deploy/shitu_index_manager
|
||||
mkdir models
|
||||
cd models
|
||||
# 下载及解压识别模型
|
||||
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/PP-ShiTuV2/general_PPLCNetV2_base_pretrained_v1.0_infer.tar && tar -xf general_PPLCNetV2_base_pretrained_v1.0_infer.tar
|
||||
cd ..
|
||||
# 下载及解压示例数据集
|
||||
wget https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/drink_dataset_v2.0.tar && tar -xf drink_dataset_v2.0.tar
|
||||
```
|
||||
|
||||
<a name="2.3"></a>
|
||||
|
||||
|
@ -139,9 +151,26 @@ pip install pyqt5
|
|||
|
||||
```shell
|
||||
cd ${PaddleClas}/deploy/shitu_index_manager
|
||||
python index_manager.py -c ../configs/inference_drink.yaml
|
||||
cp ../configs/inference_drink.yaml .
|
||||
# 注意如果没有按照2.2中准备数据集及代码,请手动修改inference_drink.yaml,做好适配
|
||||
python index_manager.py -c inference_drink.yaml
|
||||
```
|
||||
|
||||
运行成功后,会自动跳转到工具界面,可以按照如下步骤,生成新的index库。
|
||||
|
||||
1. 点击菜单栏`新建图像库`,会提示打开一个文件夹,此时请创建一个**新的文件夹**,并打开。如在`${PaddleClas}/deploy/shitu_index_manager`下新建一个`drink_index`文件夹
|
||||
2. 导入图像,或者如上面功能介绍,自己手动新增类别和相应的图像,下面介绍两种导入图像方式,操作时,二选一即可。
|
||||
- 点击`导入图像`->`导入image_list图像`,打开`${PaddleClas}/deploy/shitu_index_manager/drink_dataset_v2.0/gallery/drink_label.txt`,此时就可以将`drink_label.txt`中的图像全部导入进来,图像类别就是`drink_label.txt`中记录的类别。
|
||||
- 点击`导入图像`->`导入多文件夹图像`,打开`${PaddleClas}/deploy/shitu_index_manager/drink_dataset_v2.0/gallery/`文件夹,此时就将`gallery`文件夹下,所有子文件夹都导入进来,图像类别就是子文件夹的名字。
|
||||
3. 点击菜单栏中`新建/重建 索引库`,此时就会开始生成索引库。如果图片较多或者使用cpu来进行特征提取,那么耗时会比较长,请耐心等待。
|
||||
4. 生成索引库成功后,会发现在`drink_index`文件夹下生成如[3](#3) 中介绍的文件,此时`index`子文件夹下生成的文件,就是`PP-ShiTu`所使用的索引文件。
|
||||
|
||||
**注意**:
|
||||
|
||||
- 利用此工具生成的index库,如`drink_index`文件夹,请妥善存储。之后,可以继续使用此工具中`打开图像库`功能,打开`drink_index`文件夹,继续对index库进行增删改查操作,具体功能可以查看[功能介绍](#1)。
|
||||
- 打开一个生成好的库,在其上面进行增删改查操作后,请及时保存。保存后并及时使用菜单中`更新索引库`功能,对索引库进行更新
|
||||
- 如果要使用自己的图像库文件,图像生成格式如示例数据格式,生成`image_list.txt`或者多文件夹存储,二选一。
|
||||
|
||||
<a name="3"></a>
|
||||
|
||||
## 3. 生成文件介绍
|
||||
|
@ -150,10 +179,10 @@ python index_manager.py -c ../configs/inference_drink.yaml
|
|||
|
||||
```shell
|
||||
index_root/ # 库存储目录
|
||||
|-- image_list.txt # 图像列表,每行:image_path label。由前端生成及修改,后端只读
|
||||
|-- image_list.txt # 图像列表,每行:image_path label。由前端生成及修改,后端只读
|
||||
|-- images # 图像存储目录,由前端生成及增删查等操作。后端只读
|
||||
| |-- md5.jpg
|
||||
| |-- md5.jpg
|
||||
| |-- md5.jpg
|
||||
| |-- md5.jpg
|
||||
| |-- ……
|
||||
|-- features.pkl # 建库之后,保存的embedding向量,后端生成,前端无需操作
|
||||
|-- index # 真正的生成的index库存储目录,后端生成及操作,前端无需操作。
|
||||
|
@ -192,4 +221,3 @@ index_root/ # 库存储目录
|
|||
- 问题4: 报错 图像与index库不一致
|
||||
|
||||
答:可能用户自己修改了image_list.txt,修改完成后,请及时更新index库,保证其一致。
|
||||
|
||||
|
|
Loading…
Reference in New Issue