diff --git a/deploy/configs/build_cartoon.yaml b/deploy/configs/build_cartoon.yaml index c73279801..f9ec401f0 100644 --- a/deploy/configs/build_cartoon.yaml +++ b/deploy/configs/build_cartoon.yaml @@ -28,11 +28,11 @@ RecPostProcess: null # indexing engine config IndexProcess: - index_path: "./recognition_demo_data_v1.0/gallery_cartoon/index/" - image_root: "./recognition_demo_data_v1.0/gallery_cartoon/" - data_file: "./recognition_demo_data_v1.0/gallery_cartoon/data_file.txt" - append_index: False + index_method: "HNSW32" # supported: HNSW32, IVF, Flat + index_dir: "./recognition_demo_data_v1.1/gallery_cartoon/index/" + image_root: "./recognition_demo_data_v1.1/gallery_cartoon/" + data_file: "./recognition_demo_data_v1.1/gallery_cartoon/data_file.txt" + index_operation: "new" # suported: "append", "remove", "new" delimiter: "\t" dist_type: "IP" - pq_size: 100 embedding_size: 2048 diff --git a/deploy/configs/build_logo.yaml b/deploy/configs/build_logo.yaml index 5be17ed97..b999560a7 100644 --- a/deploy/configs/build_logo.yaml +++ b/deploy/configs/build_logo.yaml @@ -26,11 +26,11 @@ RecPostProcess: null # indexing engine config IndexProcess: - index_path: "./recognition_demo_data_v1.0/gallery_logo/index/" - image_root: "./recognition_demo_data_v1.0/gallery_logo/" - data_file: "./recognition_demo_data_v1.0/gallery_logo/data_file.txt" - append_index: False + index_method: "HNSW32" # supported: HNSW32, IVF, Flat + index_dir: "./recognition_demo_data_v1.1/gallery_logo/index/" + image_root: "./recognition_demo_data_v1.1/gallery_logo/" + data_file: "./recognition_demo_data_v1.1/gallery_logo/data_file.txt" + index_operation: "new" # suported: "append", "remove", "new" delimiter: "\t" dist_type: "IP" - pq_size: 100 embedding_size: 512 diff --git a/deploy/configs/build_product.yaml b/deploy/configs/build_product.yaml index 59e3b29ba..f9b03a7f5 100644 --- a/deploy/configs/build_product.yaml +++ b/deploy/configs/build_product.yaml @@ -26,11 +26,11 @@ RecPostProcess: null # indexing engine config IndexProcess: - index_path: "./recognition_demo_data_v1.0/gallery_product/index" - image_root: "./recognition_demo_data_v1.0/gallery_product/" - data_file: "./recognition_demo_data_v1.0/gallery_product/data_file.txt" - append_index: False + index_method: "HNSW32" # supported: HNSW32, IVF, Flat + index_dir: "./recognition_demo_data_v1.1/gallery_product/index" + image_root: "./recognition_demo_data_v1.1/gallery_product/" + data_file: "./recognition_demo_data_v1.1/gallery_product/data_file.txt" + index_operation: "new" # suported: "append", "remove", "new" delimiter: "\t" dist_type: "IP" - pq_size: 100 embedding_size: 512 diff --git a/deploy/configs/build_vehicle.yaml b/deploy/configs/build_vehicle.yaml index be095f4e1..b737c2b86 100644 --- a/deploy/configs/build_vehicle.yaml +++ b/deploy/configs/build_vehicle.yaml @@ -26,11 +26,11 @@ RecPostProcess: null # indexing engine config IndexProcess: - index_path: "./recognition_demo_data_v1.0/gallery_vehicle/index/" - image_root: "./recognition_demo_data_v1.0/gallery_vehicle/" - data_file: "./recognition_demo_data_v1.0/gallery_vehicle/data_file.txt" - append_index: False + index_method: "HNSW32" # supported: HNSW32, IVF, Flat + index_dir: "./recognition_demo_data_v1.1/gallery_vehicle/index/" + image_root: "./recognition_demo_data_v1.1/gallery_vehicle/" + data_file: "./recognition_demo_data_v1.1/gallery_vehicle/data_file.txt" + index_operation: "new" # suported: "append", "remove", "new" delimiter: "\t" dist_type: "IP" - pq_size: 100 embedding_size: 512 diff --git a/deploy/configs/inference_cartoon.yaml b/deploy/configs/inference_cartoon.yaml index fb3455302..83d2b246a 100644 --- a/deploy/configs/inference_cartoon.yaml +++ b/deploy/configs/inference_cartoon.yaml @@ -1,5 +1,5 @@ Global: - infer_imgs: "./recognition_demo_data_v1.0/test_cartoon" + infer_imgs: "./recognition_demo_data_v1.1/test_cartoon" det_inference_model_dir: "./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer/" rec_inference_model_dir: "./models/cartoon_rec_ResNet50_iCartoon_v1.0_infer/" rec_nms_thresold: 0.05 @@ -51,8 +51,6 @@ RecPreProcess: RecPostProcess: null IndexProcess: - index_path: "./recognition_demo_data_v1.0/gallery_cartoon/index/" - search_budget: 100 + index_dir: "./recognition_demo_data_v1.1/gallery_cartoon/index/" return_k: 5 - dist_type: "IP" score_thres: 0.5 diff --git a/deploy/configs/inference_logo.yaml b/deploy/configs/inference_logo.yaml index 8be5c33ad..a89d90be9 100644 --- a/deploy/configs/inference_logo.yaml +++ b/deploy/configs/inference_logo.yaml @@ -1,5 +1,5 @@ Global: - infer_imgs: "./recognition_demo_data_v1.0/test_logo" + infer_imgs: "./recognition_demo_data_v1.1/test_logo" det_inference_model_dir: "./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer/" rec_inference_model_dir: "./models/logo_rec_ResNet50_Logo3K_v1.0_infer/" rec_nms_thresold: 0.05 @@ -50,8 +50,6 @@ RecPostProcess: null # indexing engine config IndexProcess: - index_path: "./recognition_demo_data_v1.0/gallery_logo/index/" - search_budget: 100 + index_dir: "./recognition_demo_data_v1.1/gallery_logo/index/" return_k: 5 - dist_type: "IP" score_thres: 0.5 diff --git a/deploy/configs/inference_product.yaml b/deploy/configs/inference_product.yaml index 871d55d55..70f3c2fce 100644 --- a/deploy/configs/inference_product.yaml +++ b/deploy/configs/inference_product.yaml @@ -1,5 +1,5 @@ Global: - infer_imgs: "./recognition_demo_data_v1.0/test_product/daoxiangcunjinzhubing_6.jpg" + infer_imgs: "./recognition_demo_data_v1.1/test_product/daoxiangcunjinzhubing_6.jpg" det_inference_model_dir: "./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer" rec_inference_model_dir: "./models/product_ResNet50_vd_aliproduct_v1.0_infer" rec_nms_thresold: 0.05 @@ -50,8 +50,6 @@ RecPostProcess: null # indexing engine config IndexProcess: - index_path: "./recognition_demo_data_v1.0/gallery_product/index" - search_budget: 100 + index_dir: "./recognition_demo_data_v1.1/gallery_product/index" return_k: 5 - dist_type: "IP" score_thres: 0.5 diff --git a/deploy/configs/inference_vehicle.yaml b/deploy/configs/inference_vehicle.yaml index 8edcb8d5d..0152d39fb 100644 --- a/deploy/configs/inference_vehicle.yaml +++ b/deploy/configs/inference_vehicle.yaml @@ -1,5 +1,5 @@ Global: - infer_imgs: "./recognition_demo_data_v1.0/test_vehicle/" + infer_imgs: "./recognition_demo_data_v1.1/test_vehicle/" det_inference_model_dir: "./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer/" rec_inference_model_dir: "./models/vehicle_cls_ResNet50_CompCars_v1.0_infer/" rec_nms_thresold: 0.05 @@ -52,8 +52,6 @@ RecPostProcess: null # indexing engine config IndexProcess: - index_path: "./recognition_demo_data_v1.0/gallery_vehicle/index/" - search_budget: 100 + index_dir: "./recognition_demo_data_v1.1/gallery_vehicle/index/" return_k: 5 - dist_type: "IP" score_thres: 0.5 diff --git a/deploy/python/build_gallery.py b/deploy/python/build_gallery.py index a7297366d..8412f99f2 100644 --- a/deploy/python/build_gallery.py +++ b/deploy/python/build_gallery.py @@ -17,13 +17,13 @@ import sys __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(os.path.abspath(os.path.join(__dir__, '../'))) -import copy import cv2 +import faiss import numpy as np from tqdm import tqdm +import pickle from python.predict_rec import RecPredictor -from vector_search import Graph_Index from utils import logger from utils import config @@ -31,9 +31,9 @@ from utils import config def split_datafile(data_file, image_root, delimiter="\t"): ''' - data_file: image path and info, which can be splitted by spacer + data_file: image path and info, which can be splitted by spacer image_root: image path root - delimiter: delimiter + delimiter: delimiter ''' gallery_images = [] gallery_docs = [] @@ -45,9 +45,8 @@ def split_datafile(data_file, image_root, delimiter="\t"): assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}" image_file = os.path.join(image_root, line[0]) - image_doc = line[1] gallery_images.append(image_file) - gallery_docs.append(image_doc) + gallery_docs.append(ori_line.strip()) return gallery_images, gallery_docs @@ -64,9 +63,91 @@ class GalleryBuilder(object): ''' build index from scratch ''' + operation_method = config.get("index_operation", "new").lower() + gallery_images, gallery_docs = split_datafile( config['data_file'], config['image_root'], config['delimiter']) + # when remove data in index, do not need extract fatures + if operation_method != "remove": + gallery_features = self._extract_features(gallery_images, config) + + assert operation_method in [ + "new", "remove", "append" + ], "Only append, remove and new operation are supported" + + # vector.index: faiss index file + # id_map.pkl: use this file to map id to image_doc + if operation_method in ["remove", "append"]: + # if remove or append, vector.index and id_map.pkl must exist + assert os.path.join( + config["index_dir"], "vector.index" + ), "The vector.index dose not exist in {} when 'index_operation' is not None".format( + config["index_dir"]) + assert os.path.join( + config["index_dir"], "id_map.pkl" + ), "The id_map.pkl dose not exist in {} when 'index_operation' is not None".format( + config["index_dir"]) + index = faiss.read_index( + os.path.join(config["index_dir"], "vector.index")) + with open(os.path.join(config["index_dir"], "id_map.pkl"), + 'rb') as fd: + ids = pickle.load(fd) + assert index.ntotal == len(ids.keys( + )), "data number in index is not equal in in id_map" + else: + if not os.path.exists(config["index_dir"]): + os.makedirs(config["index_dir"], exist_ok=True) + index_method = config.get("index_method", "HNSW32") + + # if IVF method, cal ivf number automaticlly + if index_method == "IVF": + index_method = index_method + str( + min(int(len(gallery_images) // 8), 65536)) + ",Flat" + dist_type = faiss.METRIC_INNER_PRODUCT if config[ + "dist_type"] == "IP" else faiss.METRIC_L2 + index = faiss.index_factory(config["embedding_size"], index_method, + dist_type) + index = faiss.IndexIDMap2(index) + ids = {} + + if config["index_method"] == "HNSW32": + logger.warning( + "The HNSW32 method dose not support 'remove' operation") + + if operation_method != "remove": + # calculate id for new data + start_id = max(ids.keys()) + 1 if ids else 0 + ids_now = ( + np.arange(0, len(gallery_images)) + start_id).astype(np.int64) + + # only train when new index file + if operation_method == "new": + index.train(gallery_features) + index.add_with_ids(gallery_features, ids_now) + + for i, d in zip(list(ids_now), gallery_docs): + ids[i] = d + else: + if config["index_method"] == "HNSW32": + raise RuntimeError( + "The index_method: HNSW32 dose not support 'remove' operation" + ) + # remove ids in id_map, remove index data in faiss index + remove_ids = list( + filter(lambda k: ids.get(k) in gallery_docs, ids.keys())) + remove_ids = np.asarray(remove_ids) + index.remove_ids(remove_ids) + for k in remove_ids: + del ids[k] + + # store faiss index file and id_map file + faiss.write_index(index, + os.path.join(config["index_dir"], "vector.index")) + with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd: + pickle.dump(ids, fd) + + def _extract_features(self, gallery_images, config): # extract gallery features gallery_features = np.zeros( [len(gallery_images), config['embedding_size']], dtype=np.float32) @@ -91,19 +172,11 @@ class GalleryBuilder(object): rec_feat = self.rec_predictor.predict(batch_img) gallery_features[-len(batch_img):, :] = rec_feat batch_img = [] - - # train index - self.Searcher = Graph_Index(dist_type=config['dist_type']) - self.Searcher.build( - gallery_vectors=gallery_features, - gallery_docs=gallery_docs, - pq_size=config['pq_size'], - index_path=config['index_path'], - append_index=config["append_index"]) + return gallery_features def main(config): - system_builder = GalleryBuilder(config) + GalleryBuilder(config) return diff --git a/deploy/python/predict_system.py b/deploy/python/predict_system.py index 3f5d63a81..79c1ea703 100644 --- a/deploy/python/predict_system.py +++ b/deploy/python/predict_system.py @@ -1,6 +1,6 @@ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -20,10 +20,11 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '../'))) import copy import cv2 import numpy as np +import faiss +import pickle from python.predict_rec import RecPredictor from python.predict_det import DetPredictor -from vector_search import Graph_Index from utils import logger from utils import config @@ -40,11 +41,16 @@ class SystemPredictor(object): assert 'IndexProcess' in config.keys(), "Index config not found ... " self.return_k = self.config['IndexProcess']['return_k'] - self.search_budget = self.config['IndexProcess']['search_budget'] - self.Searcher = Graph_Index( - dist_type=config['IndexProcess']['dist_type']) - self.Searcher.load(config['IndexProcess']['index_path']) + index_dir = self.config["IndexProcess"]["index_dir"] + assert os.path.exists(os.path.join( + index_dir, "vector.index")), "vector.index not found ..." + assert os.path.exists(os.path.join( + index_dir, "id_map.pkl")), "id_map.pkl not found ... " + self.Searcher = faiss.read_index( + os.path.join(index_dir, "vector.index")) + with open(os.path.join(index_dir, "id_map.pkl"), "rb") as fd: + self.id_map = pickle.load(fd) def append_self(self, results, shape): results.append({ @@ -98,14 +104,11 @@ class SystemPredictor(object): crop_img = img[ymin:ymax, xmin:xmax, :].copy() rec_results = self.rec_predictor.predict(crop_img) preds["bbox"] = [xmin, ymin, xmax, ymax] - scores, docs = self.Searcher.search( - query=rec_results, - return_k=self.return_k, - search_budget=self.search_budget) + scores, docs = self.Searcher.search(rec_results, self.return_k) # just top-1 result will be returned for the final - if scores[0] >= self.config["IndexProcess"]["score_thres"]: - preds["rec_docs"] = docs[0] - preds["rec_scores"] = scores[0] + if scores[0][0] >= self.config["IndexProcess"]["score_thres"]: + preds["rec_docs"] = self.id_map[docs[0][0]].split()[1] + preds["rec_scores"] = scores[0][0] output.append(preds) # st5: nms to the final results to avoid fetching duplicate results diff --git a/deploy/vector_search/README.md b/deploy/vector_search/README.md index a921c2345..afa1dc282 100644 --- a/deploy/vector_search/README.md +++ b/deploy/vector_search/README.md @@ -1,6 +1,6 @@ # 向量检索 - +**注意**:由于系统适配性问题,在新版本中,此检索算法将被废弃。新版本中将使用[faiss](https://github.com/facebookresearch/faiss),整体检索的过程保持不变,但建立索引及检索时的yaml文件有所修改。 ## 1. 简介 一些垂域识别任务(如车辆、商品等)需要识别的类别数较大,往往采用基于检索的方式,通过查询向量与底库向量进行快速的最近邻搜索,获得匹配的预测类别。向量检索模块提供基础的近似最近邻搜索算法,基于百度自研的Möbius算法,一种基于图的近似最近邻搜索算法,用于最大内积搜索 (MIPS)。 该模块提供python接口,支持numpy和 tensor类型向量,支持L2和Inner Product距离计算。 diff --git a/deploy/vector_search/README_en.md b/deploy/vector_search/README_en.md index 20b253948..aecadfd95 100644 --- a/deploy/vector_search/README_en.md +++ b/deploy/vector_search/README_en.md @@ -1,5 +1,7 @@ # Vector search +**Attention**: Due to the system adaptability problem, this retrieval algorithm will be abandoned in the new version. [faiss](https://github.com/facebookresearch/faiss) will be used in the new version. The use process of the overall retrieval system base will remain unchanged, but the yaml files for build indexes and retrieval will be modified. + ## 1. Introduction Some vertical domain recognition tasks (e.g., vehicles, commodities, etc.) require a large number of recognized categories, and often use a retrieval-based approach to obtain matching predicted categories by performing a fast nearest neighbor search with query vectors and underlying library vectors. The vector search module provides the basic approximate nearest neighbor search algorithm based on Baidu's self-developed Möbius algorithm, a graph-based approximate nearest neighbor search algorithm for maximum inner product search (MIPS). This module provides python interface, supports numpy and tensor type vectors, and supports L2 and Inner Product distance calculation. @@ -57,7 +59,7 @@ brew install gcc 1. If prompted with `Error: Running Homebrew as root is extremely dangerous and no longer supported... `, refer to this [link](https://jingyan.baidu.com/article/e52e3615057a2840c60c519c.html) 2. If prompted with `Error: Failure while executing; tar --extract --no-same-owner --file... `, refer to this [link](https://blog.csdn.net/Dawn510/article/details/117787358). -After installation the compiled executable is copied under /usr/local/bin, look at the gcc in this folder: +After installation the compiled executable is copied under /usr/local/bin, look at the gcc in this folder: ``` ls /usr/local/bin/gcc* diff --git a/docs/en/tutorials/quick_start_recognition_en.md b/docs/en/tutorials/quick_start_recognition_en.md index dcea1b93c..cd14d2bb7 100644 --- a/docs/en/tutorials/quick_start_recognition_en.md +++ b/docs/en/tutorials/quick_start_recognition_en.md @@ -40,11 +40,11 @@ The detection model with the recognition inference model for the 4 directions (L | Logo Recognition Model | Logo Scenario | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/logo_rec_ResNet50_Logo3K_v1.0_infer.tar) | [inference_logo.yaml](../../../deploy/configs/inference_logo.yaml) | [build_logo.yaml](../../../deploy/configs/build_logo.yaml) | | Cartoon Face Recognition Model| Cartoon Face Scenario | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/cartoon_rec_ResNet50_iCartoon_v1.0_infer.tar) | [inference_cartoon.yaml](../../../deploy/configs/inference_cartoon.yaml) | [build_cartoon.yaml](../../../deploy/configs/build_cartoon.yaml) | | Vehicle Fine-Grained Classfication Model | Vehicle Scenario | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/vehicle_cls_ResNet50_CompCars_v1.0_infer.tar) | [inference_vehicle.yaml](../../../deploy/configs/inference_vehicle.yaml) | [build_vehicle.yaml](../../../deploy/configs/build_vehicle.yaml) | -| Product Recignition Model | Product Scenario | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/product_ResNet50_vd_Inshop_v1.0_infer.tar) | [inference_product.yaml](../../../deploy/configs/inference_product.yaml) | [build_product.yaml](../../../deploy/configs/build_product.yaml) | +| Product Recignition Model | Product Scenario | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/product_ResNet50_vd_aliproduct_v1.0_infer.tar) | [inference_product.yaml](../../../deploy/configs/inference_product.yaml) | [build_product.yaml](../../../deploy/configs/build_product.yaml) | | Vehicle ReID Model | Vehicle ReID Scenario | [Model Download Link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/vehicle_reid_ResNet50_VERIWild_v1.0_infer.tar) | - | - | -Demo data in this tutorial can be downloaded here: [download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/recognition_demo_data_en_v1.0.tar). +Demo data in this tutorial can be downloaded here: [download link](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/recognition_demo_data_en_v1.1.tar). **Attention** diff --git a/docs/images/ml_illustration.jpg b/docs/images/ml_illustration.jpg new file mode 100644 index 000000000..69dced96b Binary files /dev/null and b/docs/images/ml_illustration.jpg differ diff --git a/docs/images/ml_pipeline.jpg b/docs/images/ml_pipeline.jpg new file mode 100644 index 000000000..cac650892 Binary files /dev/null and b/docs/images/ml_pipeline.jpg differ diff --git a/docs/images/wx_group.png b/docs/images/wx_group.png index 60259dd4e..4a410ffc8 100644 Binary files a/docs/images/wx_group.png and b/docs/images/wx_group.png differ diff --git a/docs/zh_CN/metric_learning.md b/docs/zh_CN/metric_learning.md new file mode 100644 index 000000000..6c94c87f4 --- /dev/null +++ b/docs/zh_CN/metric_learning.md @@ -0,0 +1,26 @@ +# Metric Learning + +## 简介 + 在机器学习中,我们经常会遇到度量数据间距离的问题。一般来说,对于可度量的数据,我们可以直接通过欧式距离(Euclidean Distance),向量内积(Inner Product)或者是余弦相似度(Cosine Similarity)来进行计算。但对于非结构化数据来说,我们却很难进行这样的操作,如计算一段视频和一首音乐的匹配程度。由于数据格式的不同,我们难以直接进行上述的向量运算,但先验知识告诉我们ED(laugh_video, laugh_music) < ED(laugh_video, blue_music), 如何去有效得表征这种”距离”关系呢? 这就是Metric Learning所要研究的课题。 + + Metric learning全称是 Distance Metric Learning,它是通过机器学习的形式,根据训练数据,自动构造出一种基于特定任务的度量函数。Metric Learning的目标是学习一个变换函数(线性非线性均可)L,将数据点从原始的向量空间映射到一个新的向量空间,在新的向量空间里相似点的距离更近,非相似点的距离更远,使得度量更符合任务的要求,如下图所示。 Deep Metric Learning,就是用深度神经网络来拟合这个变换函数。 +![example](../images/ml_illustration.jpg) + + +## 应用 + Metric Learning技术在生活实际中应用广泛,如我们耳熟能详的人脸识别(Face Recognition)、行人重识别(Person ReID)、图像检索(Image Retrieval)、细粒度分类(Fine-gained classification)等. 随着深度学习在工业实践中越来越广泛的应用,目前大家研究的方向基本都偏向于Deep Metric Learning(DML). + + 一般来说, DML包含三个部分: 特征提取网络来map embedding, 一个采样策略来将一个mini-batch里的样本组合成很多个sub-set, 最后loss function在每个sub-set上计算loss. 如下图所示: + ![image](../images/ml_pipeline.jpg) + + +## 算法 + Metric Learning主要有如下两种学习范式: +### 1. Classification based: + 这是一类基于分类标签的Metric Learning方法。这类方法通过将每个样本分类到正确的类别中,来学习有效的特征表示,学习过程中需要每个样本的显式标签参与Loss计算。常见的算法有[L2-Softmax](https://arxiv.org/abs/1703.09507), [Large-margin Softmax](https://arxiv.org/abs/1612.02295), [Angular Softmax](https://arxiv.org/pdf/1704.08063.pdf), [NormFace](https://arxiv.org/abs/1704.06369), [AM-Softmax](https://arxiv.org/abs/1801.05599), [CosFace](https://arxiv.org/abs/1801.09414), [ArcFace](https://arxiv.org/abs/1801.07698)等。 + 这类方法也被称作是proxy-based, 因为其本质上优化的是样本和一堆proxies之间的相似度。 +### 2. Pairwise based: + 这是一类基于样本对的学习范式。他以样本对作为输入,通过直接学习样本对之间的相似度来得到有效的特征表示,常见的算法包括:[Contrastive loss](http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf), [Triplet loss](https://arxiv.org/abs/1503.03832), [Lifted-Structure loss](https://arxiv.org/abs/1511.06452), [N-pair loss](https://papers.nips.cc/paper/2016/file/6b180037abbebea991d8b1232f8a8ca9-Paper.pdf), [Multi-Similarity loss](https://arxiv.org/pdf/1904.06627.pdf)等 + +2020年发表的[CircleLoss](https://arxiv.org/abs/2002.10857),从一个全新的视角统一了两种学习范式,让研究人员和从业者对Metric Learning问题有了更进一步的思考。 + diff --git a/docs/zh_CN/tutorials/quick_start_recognition.md b/docs/zh_CN/tutorials/quick_start_recognition.md index 81518f390..2e2fab831 100644 --- a/docs/zh_CN/tutorials/quick_start_recognition.md +++ b/docs/zh_CN/tutorials/quick_start_recognition.md @@ -44,7 +44,7 @@ | 车辆ReID模型 | 车辆ReID场景 | [模型下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/vehicle_reid_ResNet50_VERIWild_v1.0_infer.tar) | - | - | -本章节demo数据下载地址如下: [数据下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/recognition_demo_data_v1.0.tar)。 +本章节demo数据下载地址如下: [数据下载链接](https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/data/recognition_demo_data_v1.1.tar)。 **注意** diff --git a/hubconf.py b/hubconf.py index eb114bc20..b7f76745a 100644 --- a/hubconf.py +++ b/hubconf.py @@ -41,10 +41,15 @@ class _SysPathG(object): self.path) -with _SysPathG( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), 'ppcls', 'arch')): - import backbone +with _SysPathG(os.path.dirname(os.path.abspath(__file__)), ): + import ppcls + import ppcls.arch.backbone as backbone + + def ppclas_init(): + if ppcls.utils.logger._logger is None: + ppcls.utils.logger.init_logger() + + ppclas_init() def _load_pretrained_parameters(model, name): url = 'https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/{}_pretrained.pdparams'.format( @@ -63,9 +68,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `AlexNet` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.AlexNet(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'AlexNet') return model @@ -80,9 +84,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `VGG11` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.VGG11(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'VGG11') return model @@ -97,9 +100,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `VGG13` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.VGG13(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'VGG13') return model @@ -114,9 +116,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `VGG16` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.VGG16(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'VGG16') return model @@ -131,9 +132,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `VGG19` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.VGG19(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'VGG19') return model @@ -149,9 +149,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `ResNet18` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.ResNet18(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'ResNet18') return model @@ -167,9 +166,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `ResNet34` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.ResNet34(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'ResNet34') return model @@ -185,9 +183,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `ResNet50` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.ResNet50(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'ResNet50') return model @@ -203,9 +200,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `ResNet101` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.ResNet101(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'ResNet101') return model @@ -221,9 +217,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `ResNet152` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.ResNet152(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'ResNet152') return model @@ -237,9 +232,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `SqueezeNet1_0` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.SqueezeNet1_0(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'SqueezeNet1_0') return model @@ -253,9 +247,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `SqueezeNet1_1` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.SqueezeNet1_1(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'SqueezeNet1_1') return model @@ -271,9 +264,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `DenseNet121` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.DenseNet121(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'DenseNet121') return model @@ -289,9 +281,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `DenseNet161` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.DenseNet161(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'DenseNet161') return model @@ -307,9 +298,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `DenseNet169` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.DenseNet169(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'DenseNet169') return model @@ -325,9 +315,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `DenseNet201` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.DenseNet201(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'DenseNet201') return model @@ -343,9 +332,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `DenseNet264` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.DenseNet264(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'DenseNet264') return model @@ -359,9 +347,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `InceptionV3` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.InceptionV3(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'InceptionV3') return model @@ -375,9 +362,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `InceptionV4` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.InceptionV4(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'InceptionV4') return model @@ -391,9 +377,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `GoogLeNet` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.GoogLeNet(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'GoogLeNet') return model @@ -407,9 +392,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `ShuffleNetV2_x0_25` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.ShuffleNetV2_x0_25(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'ShuffleNetV2_x0_25') return model @@ -423,9 +407,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `MobileNetV1` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.MobileNetV1(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'MobileNetV1') return model @@ -439,9 +422,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `MobileNetV1_x0_25` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.MobileNetV1_x0_25(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'MobileNetV1_x0_25') return model @@ -455,9 +437,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `MobileNetV1_x0_5` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.MobileNetV1_x0_5(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'MobileNetV1_x0_5') return model @@ -471,9 +452,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `MobileNetV1_x0_75` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.MobileNetV1_x0_75(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'MobileNetV1_x0_75') return model @@ -487,9 +467,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `MobileNetV2_x0_25` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.MobileNetV2_x0_25(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'MobileNetV2_x0_25') return model @@ -503,9 +482,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `MobileNetV2_x0_5` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.MobileNetV2_x0_5(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'MobileNetV2_x0_5') return model @@ -519,9 +497,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `MobileNetV2_x0_75` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.MobileNetV2_x0_75(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'MobileNetV2_x0_75') return model @@ -535,9 +512,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `MobileNetV2_x1_5` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.MobileNetV2_x1_5(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'MobileNetV2_x1_5') return model @@ -551,9 +527,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `MobileNetV2_x2_0` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.MobileNetV2_x2_0(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'MobileNetV2_x2_0') return model @@ -567,10 +542,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `MobileNetV3_large_x0_35` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.MobileNetV3_large_x0_35(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, - 'MobileNetV3_large_x0_35') return model @@ -584,10 +557,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `MobileNetV3_large_x0_5` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.MobileNetV3_large_x0_5(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, - 'MobileNetV3_large_x0_5') return model @@ -601,10 +572,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `MobileNetV3_large_x0_75` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.MobileNetV3_large_x0_75(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, - 'MobileNetV3_large_x0_75') return model @@ -618,10 +587,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `MobileNetV3_large_x1_0` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.MobileNetV3_large_x1_0(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, - 'MobileNetV3_large_x1_0') return model @@ -635,10 +602,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `MobileNetV3_large_x1_25` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.MobileNetV3_large_x1_25(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, - 'MobileNetV3_large_x1_25') return model @@ -652,10 +617,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `MobileNetV3_small_x0_35` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.MobileNetV3_small_x0_35(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, - 'MobileNetV3_small_x0_35') return model @@ -669,10 +632,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `MobileNetV3_small_x0_5` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.MobileNetV3_small_x0_5(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, - 'MobileNetV3_small_x0_5') return model @@ -686,10 +647,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `MobileNetV3_small_x0_75` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.MobileNetV3_small_x0_75(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, - 'MobileNetV3_small_x0_75') return model @@ -703,10 +662,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `MobileNetV3_small_x1_0` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.MobileNetV3_small_x1_0(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, - 'MobileNetV3_small_x1_0') return model @@ -720,10 +677,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `MobileNetV3_small_x1_25` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.MobileNetV3_small_x1_25(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, - 'MobileNetV3_small_x1_25') return model @@ -737,9 +692,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `ResNeXt101_32x4d` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.ResNeXt101_32x4d(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'ResNeXt101_32x4d') return model @@ -753,9 +707,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `ResNeXt101_64x4d` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.ResNeXt101_64x4d(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'ResNeXt101_64x4d') return model @@ -769,9 +722,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `ResNeXt152_32x4d` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.ResNeXt152_32x4d(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'ResNeXt152_32x4d') return model @@ -785,9 +737,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `ResNeXt152_64x4d` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.ResNeXt152_64x4d(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'ResNeXt152_64x4d') return model @@ -801,9 +752,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `ResNeXt50_32x4d` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.ResNeXt50_32x4d(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'ResNeXt50_32x4d') return model @@ -817,9 +767,8 @@ with _SysPathG( Returns: model: nn.Layer. Specific `ResNeXt50_64x4d` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.ResNeXt50_64x4d(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'ResNeXt50_64x4d') return model @@ -833,8 +782,7 @@ with _SysPathG( Returns: model: nn.Layer. Specific `ResNeXt50_64x4d` model depends on args. """ + kwargs.update({'pretrained': pretrained}) model = backbone.DarkNet53(**kwargs) - if pretrained: - model = _load_pretrained_parameters(model, 'DarkNet53') return model diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py index 8d507330f..b442aa883 100644 --- a/ppcls/data/__init__.py +++ b/ppcls/data/__init__.py @@ -54,8 +54,9 @@ def create_operators(params): def build_dataloader(config, mode, device, use_dali=False, seed=None): - assert mode in ['Train', 'Eval', 'Test', 'Gallery', 'Query' - ], "Mode should be Train, Eval, Test, Gallery, Query" + assert mode in [ + 'Train', 'Eval', 'Test', 'Gallery', 'Query' + ], "Dataset mode should be Train, Eval, Test, Gallery, Query" # build dataset if use_dali: from ppcls.data.dataloader.dali import dali_dataloader diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index ad5c584f0..c24a91633 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -51,6 +51,11 @@ class Engine(object): self.config = config self.eval_mode = self.config["Global"].get("eval_mode", "classification") + if "Head" in self.config["Arch"]: + self.is_rec = True + else: + self.is_rec = False + # init logger self.output_dir = self.config['Global']['output_dir'] log_file = os.path.join(self.output_dir, self.config["Arch"]["name"], @@ -106,12 +111,19 @@ class Engine(object): self.config["DataLoader"], "Eval", self.device, self.use_dali) elif self.eval_mode == "retrieval": - self.gallery_dataloader = build_dataloader( - self.config["DataLoader"]["Eval"], "Gallery", self.device, - self.use_dali) - self.query_dataloader = build_dataloader( - self.config["DataLoader"]["Eval"], "Query", self.device, - self.use_dali) + self.gallery_query_dataloader = None + if len(self.config["DataLoader"]["Eval"].keys()) == 1: + key = list(self.config["DataLoader"]["Eval"].keys())[0] + self.gallery_query_dataloader = build_dataloader( + self.config["DataLoader"]["Eval"], key, self.device, + self.use_dali) + else: + self.gallery_dataloader = build_dataloader( + self.config["DataLoader"]["Eval"], "Gallery", + self.device, self.use_dali) + self.query_dataloader = build_dataloader( + self.config["DataLoader"]["Eval"], "Query", + self.device, self.use_dali) # build loss if self.mode == "train": diff --git a/ppcls/engine/evaluation/retrieval.py b/ppcls/engine/evaluation/retrieval.py index 49d9626f6..bb6d08d31 100644 --- a/ppcls/engine/evaluation/retrieval.py +++ b/ppcls/engine/evaluation/retrieval.py @@ -23,10 +23,15 @@ from ppcls.utils import logger def retrieval_eval(evaler, epoch_id=0): evaler.model.eval() # step1. build gallery - gallery_feas, gallery_img_id, gallery_unique_id = cal_feature( - evaler, name='gallery') - query_feas, query_img_id, query_query_id = cal_feature( - evaler, name='query') + if evaler.gallery_query_dataloader is not None: + gallery_feas, gallery_img_id, gallery_unique_id = cal_feature( + evaler, name='gallery_query') + query_feas, query_img_id, query_query_id = gallery_feas, gallery_img_id, gallery_unique_id + else: + gallery_feas, gallery_img_id, gallery_unique_id = cal_feature( + evaler, name='gallery') + query_feas, query_img_id, query_query_id = cal_feature( + evaler, name='query') # step2. do evaluation sim_block_size = evaler.config["Global"].get("sim_block_size", 64) @@ -93,6 +98,8 @@ def cal_feature(evaler, name='gallery'): dataloader = evaler.gallery_dataloader elif name == 'query': dataloader = evaler.query_dataloader + elif name == 'gallery_query': + dataloader = evaler.gallery_query_dataloader else: raise RuntimeError("Only support gallery or query dataset") @@ -124,7 +131,7 @@ def cal_feature(evaler, name='gallery'): feas_norm = paddle.sqrt( paddle.sum(paddle.square(batch_feas), axis=1, keepdim=True)) batch_feas = paddle.divide(batch_feas, feas_norm) - + # do binarize if evaler.config["Global"].get("feature_binarize") == "round": batch_feas = paddle.round(batch_feas).astype("float32") * 2.0 - 1.0 @@ -142,10 +149,10 @@ def cal_feature(evaler, name='gallery'): all_image_id = paddle.concat([all_image_id, batch[1]]) if has_unique_id: all_unique_id = paddle.concat([all_unique_id, batch[2]]) - + if evaler.use_dali: dataloader_tmp.reset() - + if paddle.distributed.get_world_size() > 1: feat_list = [] img_id_list = [] diff --git a/ppcls/engine/train/train.py b/ppcls/engine/train/train.py index 9e36a063e..73f225087 100644 --- a/ppcls/engine/train/train.py +++ b/ppcls/engine/train/train.py @@ -79,7 +79,7 @@ def train_epoch(trainer, epoch_id, print_batch_step): def forward(trainer, batch): - if trainer.eval_mode == "classification": + if not trainer.is_rec: return trainer.model(batch[0]) else: return trainer.model(batch[0], batch[1]) diff --git a/ppcls/loss/celoss.py b/ppcls/loss/celoss.py index 7bc3c06cb..134797fbf 100644 --- a/ppcls/loss/celoss.py +++ b/ppcls/loss/celoss.py @@ -29,7 +29,7 @@ class CELoss(nn.Layer): self.epsilon = epsilon def _labelsmoothing(self, target, class_num): - if target.shape[-1] != class_num: + if target.ndim == 1 or target.shape[-1] != class_num: one_hot_target = F.one_hot(target, class_num) else: one_hot_target = target diff --git a/requirements.txt b/requirements.txt index 9575884ff..db2e5a08c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,4 @@ visualdl >= 2.0.0b scipy scikit-learn==0.23.2 gast==0.3.3 +faiss-cpu==1.7.1