commit 957d9ef2e4
@@ -28,11 +28,11 @@ RecPostProcess: null

 # indexing engine config
 IndexProcess:
-  index_path: "./recognition_demo_data_v1.0/gallery_cartoon/index/"
-  image_root: "./recognition_demo_data_v1.0/gallery_cartoon/"
-  data_file: "./recognition_demo_data_v1.0/gallery_cartoon/data_file.txt"
-  append_index: False
+  index_method: "HNSW32" # supported: HNSW32, IVF, Flat
+  index_dir: "./recognition_demo_data_v1.1/gallery_cartoon/index/"
+  image_root: "./recognition_demo_data_v1.1/gallery_cartoon/"
+  data_file: "./recognition_demo_data_v1.1/gallery_cartoon/data_file.txt"
+  index_operation: "new" # supported: "append", "remove", "new"
   delimiter: "\t"
   dist_type: "IP"
-  pq_size: 100
   embedding_size: 2048
@@ -26,11 +26,11 @@ RecPostProcess: null

 # indexing engine config
 IndexProcess:
-  index_path: "./recognition_demo_data_v1.0/gallery_logo/index/"
-  image_root: "./recognition_demo_data_v1.0/gallery_logo/"
-  data_file: "./recognition_demo_data_v1.0/gallery_logo/data_file.txt"
-  append_index: False
+  index_method: "HNSW32" # supported: HNSW32, IVF, Flat
+  index_dir: "./recognition_demo_data_v1.1/gallery_logo/index/"
+  image_root: "./recognition_demo_data_v1.1/gallery_logo/"
+  data_file: "./recognition_demo_data_v1.1/gallery_logo/data_file.txt"
+  index_operation: "new" # supported: "append", "remove", "new"
   delimiter: "\t"
   dist_type: "IP"
-  pq_size: 100
   embedding_size: 512
@@ -26,11 +26,11 @@ RecPostProcess: null

 # indexing engine config
 IndexProcess:
-  index_path: "./recognition_demo_data_v1.0/gallery_product/index"
-  image_root: "./recognition_demo_data_v1.0/gallery_product/"
-  data_file: "./recognition_demo_data_v1.0/gallery_product/data_file.txt"
-  append_index: False
+  index_method: "HNSW32" # supported: HNSW32, IVF, Flat
+  index_dir: "./recognition_demo_data_v1.1/gallery_product/index"
+  image_root: "./recognition_demo_data_v1.1/gallery_product/"
+  data_file: "./recognition_demo_data_v1.1/gallery_product/data_file.txt"
+  index_operation: "new" # supported: "append", "remove", "new"
   delimiter: "\t"
   dist_type: "IP"
-  pq_size: 100
   embedding_size: 512
@@ -26,11 +26,11 @@ RecPostProcess: null

 # indexing engine config
 IndexProcess:
-  index_path: "./recognition_demo_data_v1.0/gallery_vehicle/index/"
-  image_root: "./recognition_demo_data_v1.0/gallery_vehicle/"
-  data_file: "./recognition_demo_data_v1.0/gallery_vehicle/data_file.txt"
-  append_index: False
+  index_method: "HNSW32" # supported: HNSW32, IVF, Flat
+  index_dir: "./recognition_demo_data_v1.1/gallery_vehicle/index/"
+  image_root: "./recognition_demo_data_v1.1/gallery_vehicle/"
+  data_file: "./recognition_demo_data_v1.1/gallery_vehicle/data_file.txt"
+  index_operation: "new" # supported: "append", "remove", "new"
   delimiter: "\t"
   dist_type: "IP"
-  pq_size: 100
   embedding_size: 512
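The four build configs above drop the Möbius-specific keys (`index_path`, `append_index`, `pq_size`) in favor of faiss-oriented ones. As a minimal sketch of how those keys end up being consumed on the faiss side (mirroring the updated `build_gallery.py` later in this commit; the feature matrix and output file name here are placeholders, not repo code):

```python
import faiss
import numpy as np

embedding_size = 512                 # IndexProcess.embedding_size
index_method = "HNSW32"              # IndexProcess.index_method: HNSW32, IVF or Flat
metric = faiss.METRIC_INNER_PRODUCT  # IndexProcess.dist_type == "IP"; otherwise METRIC_L2

index = faiss.index_factory(embedding_size, index_method, metric)
index = faiss.IndexIDMap2(index)     # lets vectors be addressed by explicit int64 ids

feats = np.random.rand(16, embedding_size).astype(np.float32)  # placeholder gallery features
ids = np.arange(len(feats)).astype(np.int64)
index.train(feats)                   # no-op for HNSW32/Flat, required for IVF
index.add_with_ids(feats, ids)
faiss.write_index(index, "vector.index")  # written into IndexProcess.index_dir
```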
@@ -1,5 +1,5 @@
 Global:
-  infer_imgs: "./recognition_demo_data_v1.0/test_cartoon"
+  infer_imgs: "./recognition_demo_data_v1.1/test_cartoon"
   det_inference_model_dir: "./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer/"
   rec_inference_model_dir: "./models/cartoon_rec_ResNet50_iCartoon_v1.0_infer/"
   rec_nms_thresold: 0.05
@@ -51,8 +51,6 @@ RecPreProcess:
 RecPostProcess: null

 IndexProcess:
-  index_path: "./recognition_demo_data_v1.0/gallery_cartoon/index/"
-  search_budget: 100
+  index_dir: "./recognition_demo_data_v1.1/gallery_cartoon/index/"
   return_k: 5
   dist_type: "IP"
   score_thres: 0.5
@@ -1,5 +1,5 @@
 Global:
-  infer_imgs: "./recognition_demo_data_v1.0/test_logo"
+  infer_imgs: "./recognition_demo_data_v1.1/test_logo"
   det_inference_model_dir: "./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer/"
   rec_inference_model_dir: "./models/logo_rec_ResNet50_Logo3K_v1.0_infer/"
   rec_nms_thresold: 0.05
@@ -50,8 +50,6 @@ RecPostProcess: null

 # indexing engine config
 IndexProcess:
-  index_path: "./recognition_demo_data_v1.0/gallery_logo/index/"
-  search_budget: 100
+  index_dir: "./recognition_demo_data_v1.1/gallery_logo/index/"
   return_k: 5
   dist_type: "IP"
   score_thres: 0.5
@@ -1,5 +1,5 @@
 Global:
-  infer_imgs: "./recognition_demo_data_v1.0/test_product/daoxiangcunjinzhubing_6.jpg"
+  infer_imgs: "./recognition_demo_data_v1.1/test_product/daoxiangcunjinzhubing_6.jpg"
   det_inference_model_dir: "./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer"
   rec_inference_model_dir: "./models/product_ResNet50_vd_aliproduct_v1.0_infer"
   rec_nms_thresold: 0.05
@@ -50,8 +50,6 @@ RecPostProcess: null

 # indexing engine config
 IndexProcess:
-  index_path: "./recognition_demo_data_v1.0/gallery_product/index"
-  search_budget: 100
+  index_dir: "./recognition_demo_data_v1.1/gallery_product/index"
   return_k: 5
   dist_type: "IP"
   score_thres: 0.5
@@ -1,5 +1,5 @@
 Global:
-  infer_imgs: "./recognition_demo_data_v1.0/test_vehicle/"
+  infer_imgs: "./recognition_demo_data_v1.1/test_vehicle/"
   det_inference_model_dir: "./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer/"
   rec_inference_model_dir: "./models/vehicle_cls_ResNet50_CompCars_v1.0_infer/"
   rec_nms_thresold: 0.05
@@ -52,8 +52,6 @@ RecPostProcess: null

 # indexing engine config
 IndexProcess:
-  index_path: "./recognition_demo_data_v1.0/gallery_vehicle/index/"
-  search_budget: 100
+  index_dir: "./recognition_demo_data_v1.1/gallery_vehicle/index/"
   return_k: 5
   dist_type: "IP"
   score_thres: 0.5
@@ -17,13 +17,13 @@ import sys
 __dir__ = os.path.dirname(os.path.abspath(__file__))
 sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))

 import copy
 import cv2
+import faiss
 import numpy as np
 from tqdm import tqdm
+import pickle

 from python.predict_rec import RecPredictor
-from vector_search import Graph_Index

 from utils import logger
 from utils import config
@@ -31,9 +31,9 @@ from utils import config

 def split_datafile(data_file, image_root, delimiter="\t"):
     '''
         data_file: image path and info, which can be split by the delimiter
         image_root: image path root
         delimiter: delimiter
     '''
     gallery_images = []
     gallery_docs = []
@@ -45,9 +45,8 @@ def split_datafile(data_file, image_root, delimiter="\t"):
             assert text_num >= 2, f"line({ori_line}) must be split into at least 2 parts, but got {text_num}"
             image_file = os.path.join(image_root, line[0])
-            image_doc = line[1]
             gallery_images.append(image_file)
-            gallery_docs.append(image_doc)
+            gallery_docs.append(ori_line.strip())

     return gallery_images, gallery_docs
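The behavioral change here is that `gallery_docs` now keeps the whole stripped line rather than only the label column, so the label can later be recovered from `id_map.pkl`. A small example with an invented tab-delimited data file:

```python
# Suppose data_file.txt contains two tab-delimited lines (hypothetical):
#   gallery/0/0.jpg\tlabel_a
#   gallery/1/3.jpg\tlabel_b
images, docs = split_datafile("data_file.txt", "./gallery_root/", delimiter="\t")
# images -> ["./gallery_root/gallery/0/0.jpg", "./gallery_root/gallery/1/3.jpg"]
# docs   -> ["gallery/0/0.jpg\tlabel_a", "gallery/1/3.jpg\tlabel_b"]
#           (previously docs held only "label_a" / "label_b")
```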
@@ -64,9 +63,91 @@ class GalleryBuilder(object):
         '''
             build index from scratch
         '''
+        operation_method = config.get("index_operation", "new").lower()
+
         gallery_images, gallery_docs = split_datafile(
             config['data_file'], config['image_root'], config['delimiter'])

+        # when removing data from the index, there is no need to extract features
+        if operation_method != "remove":
+            gallery_features = self._extract_features(gallery_images, config)
+
+        assert operation_method in [
+            "new", "remove", "append"
+        ], "Only append, remove and new operation are supported"
+
+        # vector.index: faiss index file
+        # id_map.pkl: use this file to map id to image_doc
+        if operation_method in ["remove", "append"]:
+            # if remove or append, vector.index and id_map.pkl must exist
+            assert os.path.exists(os.path.join(
+                config["index_dir"], "vector.index"
+            )), "The vector.index does not exist in {} when 'index_operation' is not None".format(
+                config["index_dir"])
+            assert os.path.exists(os.path.join(
+                config["index_dir"], "id_map.pkl"
+            )), "The id_map.pkl does not exist in {} when 'index_operation' is not None".format(
+                config["index_dir"])
+            index = faiss.read_index(
+                os.path.join(config["index_dir"], "vector.index"))
+            with open(os.path.join(config["index_dir"], "id_map.pkl"),
+                      'rb') as fd:
+                ids = pickle.load(fd)
+            assert index.ntotal == len(ids.keys(
+            )), "data number in index is not equal to that in id_map"
+        else:
+            if not os.path.exists(config["index_dir"]):
+                os.makedirs(config["index_dir"], exist_ok=True)
+            index_method = config.get("index_method", "HNSW32")
+
+            # if IVF method, calculate the ivf number automatically
+            if index_method == "IVF":
+                index_method = index_method + str(
+                    min(int(len(gallery_images) // 8), 65536)) + ",Flat"
+            dist_type = faiss.METRIC_INNER_PRODUCT if config[
+                "dist_type"] == "IP" else faiss.METRIC_L2
+            index = faiss.index_factory(config["embedding_size"], index_method,
+                                        dist_type)
+            index = faiss.IndexIDMap2(index)
+            ids = {}
+
+        if config["index_method"] == "HNSW32":
+            logger.warning(
+                "The HNSW32 method does not support the 'remove' operation")
+
+        if operation_method != "remove":
+            # calculate ids for the new data
+            start_id = max(ids.keys()) + 1 if ids else 0
+            ids_now = (
+                np.arange(0, len(gallery_images)) + start_id).astype(np.int64)
+
+            # only train when building a new index file
+            if operation_method == "new":
+                index.train(gallery_features)
+            index.add_with_ids(gallery_features, ids_now)
+
+            for i, d in zip(list(ids_now), gallery_docs):
+                ids[i] = d
+        else:
+            if config["index_method"] == "HNSW32":
+                raise RuntimeError(
+                    "The index_method: HNSW32 does not support the 'remove' operation"
+                )
+            # remove ids from id_map and remove the corresponding data from the faiss index
+            remove_ids = list(
+                filter(lambda k: ids.get(k) in gallery_docs, ids.keys()))
+            remove_ids = np.asarray(remove_ids)
+            index.remove_ids(remove_ids)
+            for k in remove_ids:
+                del ids[k]
+
+        # store the faiss index file and the id_map file
+        faiss.write_index(index,
+                          os.path.join(config["index_dir"], "vector.index"))
+        with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd:
+            pickle.dump(ids, fd)
+
+    def _extract_features(self, gallery_images, config):
         # extract gallery features
         gallery_features = np.zeros(
             [len(gallery_images), config['embedding_size']], dtype=np.float32)
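The `append`/`remove` branches above rely on faiss addressing vectors by explicit ids through `IndexIDMap2`; HNSW32 indexes cannot remove vectors, which is why the code warns and later raises for that combination. A self-contained sketch of those semantics on a Flat index (data is random, not from the repo):

```python
import faiss
import numpy as np

d = 8
index = faiss.IndexIDMap2(faiss.index_factory(d, "Flat", faiss.METRIC_INNER_PRODUCT))

# "new": assign ids 0..n-1 and add
feats = np.random.rand(3, d).astype(np.float32)
index.add_with_ids(feats, np.arange(3, dtype=np.int64))

# "append": continue ids after the current maximum
more = np.random.rand(2, d).astype(np.float32)
index.add_with_ids(more, np.arange(3, 5, dtype=np.int64))

# "remove": drop selected ids (the id_map dict is pruned alongside in build())
index.remove_ids(np.array([1, 3], dtype=np.int64))
print(index.ntotal)  # 3
```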
@@ -91,19 +172,11 @@ class GalleryBuilder(object):
             rec_feat = self.rec_predictor.predict(batch_img)
             gallery_features[-len(batch_img):, :] = rec_feat
             batch_img = []

-        # train index
-        self.Searcher = Graph_Index(dist_type=config['dist_type'])
-        self.Searcher.build(
-            gallery_vectors=gallery_features,
-            gallery_docs=gallery_docs,
-            pq_size=config['pq_size'],
-            index_path=config['index_path'],
-            append_index=config["append_index"])
+        return gallery_features


 def main(config):
-    system_builder = GalleryBuilder(config)
+    GalleryBuilder(config)
     return
@@ -1,6 +1,6 @@
 # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
|
@ -20,10 +20,11 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
|
|||
import copy
|
||||
import cv2
|
||||
import numpy as np
|
||||
import faiss
|
||||
import pickle
|
||||
|
||||
from python.predict_rec import RecPredictor
|
||||
from python.predict_det import DetPredictor
|
||||
from vector_search import Graph_Index
|
||||
|
||||
from utils import logger
|
||||
from utils import config
|
||||
|
@@ -40,11 +41,16 @@ class SystemPredictor(object):

         assert 'IndexProcess' in config.keys(), "Index config not found ... "
         self.return_k = self.config['IndexProcess']['return_k']
-        self.search_budget = self.config['IndexProcess']['search_budget']

-        self.Searcher = Graph_Index(
-            dist_type=config['IndexProcess']['dist_type'])
-        self.Searcher.load(config['IndexProcess']['index_path'])
+        index_dir = self.config["IndexProcess"]["index_dir"]
+        assert os.path.exists(os.path.join(
+            index_dir, "vector.index")), "vector.index not found ..."
+        assert os.path.exists(os.path.join(
+            index_dir, "id_map.pkl")), "id_map.pkl not found ... "
+        self.Searcher = faiss.read_index(
+            os.path.join(index_dir, "vector.index"))
+        with open(os.path.join(index_dir, "id_map.pkl"), "rb") as fd:
+            self.id_map = pickle.load(fd)

     def append_self(self, results, shape):
         results.append({
|
@ -98,14 +104,11 @@ class SystemPredictor(object):
|
|||
crop_img = img[ymin:ymax, xmin:xmax, :].copy()
|
||||
rec_results = self.rec_predictor.predict(crop_img)
|
||||
preds["bbox"] = [xmin, ymin, xmax, ymax]
|
||||
scores, docs = self.Searcher.search(
|
||||
query=rec_results,
|
||||
return_k=self.return_k,
|
||||
search_budget=self.search_budget)
|
||||
scores, docs = self.Searcher.search(rec_results, self.return_k)
|
||||
# just top-1 result will be returned for the final
|
||||
if scores[0] >= self.config["IndexProcess"]["score_thres"]:
|
||||
preds["rec_docs"] = docs[0]
|
||||
preds["rec_scores"] = scores[0]
|
||||
if scores[0][0] >= self.config["IndexProcess"]["score_thres"]:
|
||||
preds["rec_docs"] = self.id_map[docs[0][0]].split()[1]
|
||||
preds["rec_scores"] = scores[0][0]
|
||||
output.append(preds)
|
||||
|
||||
# st5: nms to the final results to avoid fetching duplicate results
|
||||
|
|
|
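The switch from `scores[0]` to `scores[0][0]` follows from the faiss API: `index.search` takes a 2-D batch of queries and returns score and id matrices with one row per query, and the returned id is then mapped back to its doc line. A toy, self-contained illustration (paths and labels are invented):

```python
import faiss
import numpy as np

d = 4
index = faiss.IndexIDMap2(faiss.IndexFlatIP(d))
index.add_with_ids(np.eye(d, dtype=np.float32), np.arange(d, dtype=np.int64))
id_map = {i: f"gallery/{i}.jpg\tlabel_{i}" for i in range(d)}  # stand-in for id_map.pkl

query = np.eye(1, d, dtype=np.float32)        # one query vector, shape (1, d)
scores, docs = index.search(query, 3)          # both have shape (n_queries, return_k)
top_score, top_id = scores[0][0], docs[0][0]   # best match of the first query
label = id_map[int(top_id)].split()[1]         # "label_0"
```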
@@ -1,6 +1,6 @@
 # Vector Search

+**Note**: Due to compatibility issues, this retrieval algorithm will be deprecated in the new version, which uses [faiss](https://github.com/facebookresearch/faiss) instead. The overall retrieval workflow stays the same, but the yaml files used for building the index and for retrieval have been modified.
 ## 1. Introduction

 Some vertical-domain recognition tasks (e.g., vehicles, products) involve a large number of categories and are therefore usually handled with a retrieval-based approach: a fast nearest-neighbor search between the query vector and the gallery vectors yields the matching predicted category. The vector search module provides a basic approximate nearest neighbor search algorithm based on Baidu's self-developed Möbius algorithm, a graph-based method for maximum inner product search (MIPS). The module provides a Python interface, supports numpy and tensor vectors, and supports L2 and Inner Product distance calculation.
@@ -1,5 +1,7 @@
 # Vector search

+**Attention**: Due to compatibility issues, this retrieval algorithm will be deprecated in the new version, which uses [faiss](https://github.com/facebookresearch/faiss) instead. The overall retrieval workflow stays the same, but the yaml files used for building the index and for retrieval have been modified.
+
 ## 1. Introduction

 Some vertical domain recognition tasks (e.g., vehicles, commodities) involve a large number of categories and often use a retrieval-based approach: a fast nearest-neighbor search between the query vector and the gallery vectors yields the matching predicted category. The vector search module provides a basic approximate nearest neighbor search algorithm based on Baidu's self-developed Möbius algorithm, a graph-based method for maximum inner product search (MIPS). The module provides a Python interface, supports numpy and tensor type vectors, and supports L2 and Inner Product distance calculation.
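For reference while the module remains in the tree, its interface looks roughly like the call sites removed elsewhere in this commit. A hedged sketch, with keyword names taken from that removed code; shapes, paths, and data are guesses, not a tested example:

```python
import numpy as np
from vector_search import Graph_Index

feats = np.random.rand(100, 512).astype(np.float32)        # placeholder gallery vectors
docs = [f"img_{i}.jpg\tlabel_{i}" for i in range(100)]      # placeholder doc strings

# build side (as in the removed GalleryBuilder code)
indexer = Graph_Index(dist_type="IP")
indexer.build(gallery_vectors=feats, gallery_docs=docs,
              pq_size=100, index_path="./index/", append_index=False)

# search side (as in the removed SystemPredictor code)
searcher = Graph_Index(dist_type="IP")
searcher.load("./index/")
scores, matched_docs = searcher.search(query=feats[0], return_k=5, search_budget=100)
```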
@@ -57,7 +59,7 @@ brew install gcc
 1. If prompted with `Error: Running Homebrew as root is extremely dangerous and no longer supported... `, refer to this [link](https://jingyan.baidu.com/article/e52e3615057a2840c60c519c.html)
 2. If prompted with `Error: Failure while executing; tar --extract --no-same-owner --file... `, refer to this [link](https://blog.csdn.net/Dawn510/article/details/117787358).

 After installation, the compiled executable is copied under /usr/local/bin; check the gcc versions in this folder:

 ```
 ls /usr/local/bin/gcc*
@@ -40,11 +40,11 @@ The detection model with the recognition inference model for the 4 directions (L
 | Logo Recognition Model | Logo Scenario | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/logo_rec_ResNet50_Logo3K_v1.0_infer.tar) | [inference_logo.yaml](../../../deploy/configs/inference_logo.yaml) | [build_logo.yaml](../../../deploy/configs/build_logo.yaml) |
 | Cartoon Face Recognition Model | Cartoon Face Scenario | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/cartoon_rec_ResNet50_iCartoon_v1.0_infer.tar) | [inference_cartoon.yaml](../../../deploy/configs/inference_cartoon.yaml) | [build_cartoon.yaml](../../../deploy/configs/build_cartoon.yaml) |
 | Vehicle Fine-Grained Classification Model | Vehicle Scenario | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/vehicle_cls_ResNet50_CompCars_v1.0_infer.tar) | [inference_vehicle.yaml](../../../deploy/configs/inference_vehicle.yaml) | [build_vehicle.yaml](../../../deploy/configs/build_vehicle.yaml) |
-| Product Recognition Model | Product Scenario | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/product_ResNet50_vd_Inshop_v1.0_infer.tar) | [inference_product.yaml](../../../deploy/configs/inference_product.yaml) | [build_product.yaml](../../../deploy/configs/build_product.yaml) |
+| Product Recognition Model | Product Scenario | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/product_ResNet50_vd_aliproduct_v1.0_infer.tar) | [inference_product.yaml](../../../deploy/configs/inference_product.yaml) | [build_product.yaml](../../../deploy/configs/build_product.yaml) |
 | Vehicle ReID Model | Vehicle ReID Scenario | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/vehicle_reid_ResNet50_VERIWild_v1.0_infer.tar) | - | - |


-Demo data in this tutorial can be downloaded here: [download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/recognition_demo_data_en_v1.0.tar).
+Demo data in this tutorial can be downloaded here: [download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/recognition_demo_data_en_v1.1.tar).


 **Attention**
@@ -44,7 +44,7 @@
 | Vehicle ReID Model | Vehicle ReID Scenario | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/vehicle_reid_ResNet50_VERIWild_v1.0_infer.tar) | - | - |


-The demo data for this chapter can be downloaded here: [data download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/recognition_demo_data_v1.0.tar).
+The demo data for this chapter can be downloaded here: [data download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/recognition_demo_data_v1.1.tar).


 **Attention**
@@ -8,3 +8,4 @@ visualdl >= 2.0.0b
 scipy
 scikit-learn==0.23.2
 gast==0.3.3
+faiss-cpu==1.7.1