PaddleClas/deploy/python/build_gallery.py

211 lines
7.9 KiB
Python

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import faiss
import numpy as np
from tqdm import tqdm
import pickle
from paddleclas.deploy.utils import logger, config
from paddleclas.deploy.utils.predictor import Predictor
from paddleclas.deploy.python.predict_rec import RecPredictor
from paddleclas.deploy.python.predict_rec import RecPredictor
def split_datafile(data_file, image_root, delimiter="\t"):
    """Parse a gallery list file into image paths and their raw doc lines.

    Every non-blank line of ``data_file`` must contain at least two
    ``delimiter``-separated fields. The first field is an image path that is
    joined onto ``image_root``; the whole stripped line is kept as the
    document for that image.

    Args:
        data_file: path of the UTF-8 text file listing the gallery entries.
        image_root: directory that each listed image path is joined onto.
        delimiter: field separator used inside each line.

    Returns:
        A ``(gallery_images, gallery_docs)`` tuple of two equal-length lists:
        the joined image paths and the matching stripped source lines.

    Raises:
        AssertionError: if a non-blank line has fewer than two fields.
    """
    gallery_images = []
    gallery_docs = []
    with open(data_file, 'r', encoding='utf-8') as f:
        # Iterate the file object directly so the whole list is never held
        # in memory at once.
        for ori_line in f:
            stripped = ori_line.strip()
            # Blank lines (typically a trailing newline at the end of the
            # list) describe no entry; skip them instead of aborting.
            if not stripped:
                continue
            fields = stripped.split(delimiter)
            text_num = len(fields)
            assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
            gallery_images.append(os.path.join(image_root, fields[0]))
            gallery_docs.append(stripped)
    return gallery_images, gallery_docs
class GalleryBuilder(object):
    """Create or update a faiss gallery index from a list of images.

    Constructing an instance runs the whole pipeline: it reads the
    ``IndexProcess`` section of ``config`` and calls :meth:`build` on it.

    Two files are maintained inside ``IndexProcess.index_dir``:

    * ``vector.index`` - the faiss index file.
    * ``id_map.pkl``   - a pickled dict mapping each vector id to the raw
      gallery line (image doc) it came from.
    """

    def __init__(self, config):
        """Build the recognition predictor and immediately run the build.

        Args:
            config: full configuration mapping; must contain an
                ``IndexProcess`` entry.

        Raises:
            AssertionError: if ``IndexProcess`` is missing from ``config``.
        """
        # Fail before creating the predictor when the index section is absent.
        assert 'IndexProcess' in config.keys(), "Index config not found ... "
        self.config = config
        self.rec_predictor = RecPredictor(config)
        self.build(config['IndexProcess'])

    def build(self, config):
        """Create, extend or shrink the gallery index.

        ``config`` is the ``IndexProcess`` section. Keys read here and in the
        helpers: ``index_operation`` (``new``/``append``/``remove``, default
        ``new``), ``data_file``, ``image_root``, ``delimiter`` (default tab),
        ``index_dir``, ``index_method`` (default ``HNSW32``), ``dist_type``,
        ``embedding_size`` and ``batch_size`` (default 32).

        Raises:
            AssertionError: on an unknown operation, on missing index files
                for ``append``/``remove``, or when the stored index and id
                map disagree in size.
            RuntimeError: when ``remove`` is requested for an ``HNSW32`` index.
        """
        operation_method = config.get("index_operation", "new").lower()
        # Validate the operation before any expensive work is done.
        assert operation_method in [
            "new", "remove", "append"
        ], "Only append, remove and new operation are supported"

        gallery_images, gallery_docs = split_datafile(
            config['data_file'], config['image_root'],
            config.get('delimiter', "\t"))

        is_binary = config["dist_type"] == "hamming"
        # Same default as the one used when the index is created, so a config
        # without 'index_method' no longer raises KeyError further down.
        index_method = config.get("index_method", "HNSW32")
        index_path = os.path.join(config["index_dir"], "vector.index")
        id_map_path = os.path.join(config["index_dir"], "id_map.pkl")

        # Load or create the index before running inference so that missing
        # files or a bad index description are reported immediately.
        if operation_method in ["remove", "append"]:
            index, ids = self._load_index(config)
        else:
            index, ids = self._create_index(config, len(gallery_images))

        if index_method == "HNSW32":
            logger.warning(
                "The HNSW32 method dose not support 'remove' operation")

        if operation_method != "remove":
            # when removing data from the index, features are not needed
            gallery_features = self._extract_features(gallery_images, config)

            # calculate ids for the new data, continuing after the largest
            # id already present in the map
            start_id = max(ids.keys()) + 1 if ids else 0
            ids_now = (
                np.arange(0, len(gallery_images)) + start_id).astype(np.int64)

            if is_binary:
                # The binary index is not wrapped in an id map, so vectors are
                # added without explicit ids. This is done for 'append' too;
                # otherwise id_map.pkl grows while the index does not and the
                # next load fails the size check.
                # NOTE(review): assumes faiss numbers the added vectors
                # sequentially, matching ids_now - confirm if binary
                # 'remove' is ever used before 'append'.
                index.add(gallery_features)
            else:
                # only train when building a new index file
                if operation_method == "new":
                    index.train(gallery_features)
                index.add_with_ids(gallery_features, ids_now)

            for i, d in zip(list(ids_now), gallery_docs):
                ids[i] = d
        else:
            if index_method == "HNSW32":
                raise RuntimeError(
                    "The index_method: HNSW32 dose not support 'remove' operation"
                )
            # remove ids in id_map, remove index data in faiss index
            docs_to_remove = set(gallery_docs)
            # Explicit int64 dtype: an empty selection would otherwise become
            # a float64 array.
            remove_ids = np.asarray(
                [k for k, doc in ids.items() if doc in docs_to_remove],
                dtype=np.int64)
            index.remove_ids(remove_ids)
            for k in remove_ids:
                del ids[k]

        # store faiss index file and id_map file
        if is_binary:
            faiss.write_index_binary(index, index_path)
        else:
            faiss.write_index(index, index_path)

        with open(id_map_path, 'wb') as fd:
            pickle.dump(ids, fd)

    def _load_index(self, config):
        """Load the existing ``vector.index`` / ``id_map.pkl`` pair.

        Returns:
            ``(index, ids)`` - the faiss index and the id -> doc dict.

        Raises:
            AssertionError: if either file is missing or their sizes differ.
        """
        index_path = os.path.join(config["index_dir"], "vector.index")
        id_map_path = os.path.join(config["index_dir"], "id_map.pkl")

        # if remove or append, vector.index and id_map.pkl must exist
        assert os.path.exists(
            index_path
        ), "The vector.index dose not exist in {} when 'index_operation' is not None".format(
            config["index_dir"])
        assert os.path.exists(
            id_map_path
        ), "The id_map.pkl dose not exist in {} when 'index_operation' is not None".format(
            config["index_dir"])

        # Binary (hamming) indexes are stored with write_index_binary, so they
        # must be read back with the matching binary reader.
        if config["dist_type"] == "hamming":
            index = faiss.read_index_binary(index_path)
        else:
            index = faiss.read_index(index_path)

        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only point index_dir at directories you trust.
        with open(id_map_path, 'rb') as fd:
            ids = pickle.load(fd)
        assert index.ntotal == len(ids.keys(
        )), "data number in index is not equal in in id_map"
        return index, ids

    def _create_index(self, config, num_images):
        """Create an empty faiss index (and empty id map) for a new gallery.

        Args:
            config: the ``IndexProcess`` section.
            num_images: gallery size, used to derive the IVF cell count.

        Returns:
            ``(index, ids)`` - a fresh faiss index and an empty dict.
        """
        os.makedirs(config["index_dir"], exist_ok=True)

        index_method = config.get("index_method", "HNSW32")
        # for the IVF method, derive the number of cells automatically
        if index_method == "IVF":
            # At least one cell: fewer than 8 images would otherwise yield
            # the description "IVF0,Flat".
            nlist = max(1, min(num_images // 8, 65536))
            index_method = index_method + str(nlist) + ",Flat"

        if config["dist_type"] == "hamming":
            # for a binary index, add B at the head of index_method
            index_method = "B" + index_method
            index = faiss.index_binary_factory(config["embedding_size"],
                                               index_method)
        else:
            dist_type = faiss.METRIC_INNER_PRODUCT if config[
                "dist_type"] == "IP" else faiss.METRIC_L2
            index = faiss.index_factory(config["embedding_size"],
                                        index_method, dist_type)
            # wrap so vectors can be added with explicit ids
            index = faiss.IndexIDMap2(index)
        return index, {}

    def _extract_features(self, gallery_images, config):
        """Run the recognition predictor over all gallery images.

        Args:
            gallery_images: list of image file paths.
            config: the ``IndexProcess`` section (reads ``dist_type``,
                ``embedding_size`` and ``batch_size``).

        Returns:
            A numpy array with one row per image: ``uint8`` with
            ``embedding_size // 8`` columns for hamming distance, otherwise
            ``float32`` with ``embedding_size`` columns.

        Raises:
            SystemExit: if an image cannot be read.
        """
        if config["dist_type"] == "hamming":
            gallery_features = np.zeros(
                [len(gallery_images), config['embedding_size'] // 8],
                dtype=np.uint8)
        else:
            gallery_features = np.zeros(
                [len(gallery_images), config['embedding_size']],
                dtype=np.float32)

        # construct batches of images and do inference
        batch_size = config.get("batch_size", 32)
        batch_img = []
        for i, image_file in enumerate(tqdm(gallery_images)):
            img = cv2.imread(image_file)
            if img is None:
                logger.error("img empty, please check {}".format(image_file))
                # exit() is a helper injected by the site module and exits
                # with status 0; raise SystemExit directly with a failure
                # status instead.
                raise SystemExit(1)
            else:
                # reverse the channel axis (OpenCV loads BGR, giving RGB)
                img = img[:, :, ::-1]
                batch_img.append(img)

            if (i + 1) % batch_size == 0:
                # presumably predict() returns one feature row per image in
                # the batch - the slice assignment below relies on it
                rec_feat = self.rec_predictor.predict(batch_img)
                gallery_features[i - batch_size + 1:i + 1, :] = rec_feat
                batch_img = []

        # flush the final, partially filled batch
        if len(batch_img) > 0:
            rec_feat = self.rec_predictor.predict(batch_img)
            gallery_features[-len(batch_img):, :] = rec_feat
            batch_img = []

        return gallery_features
def main(config):
    """Run the gallery build described by ``config``.

    Instantiating ``GalleryBuilder`` performs the whole build inside its
    constructor, so the instance is not kept. Returns ``None``.
    """
    GalleryBuilder(config)
if __name__ == "__main__":
    # Parse the command-line arguments, then load the configuration file
    # they point at with any overrides applied (show=True is passed so the
    # loader prints the resulting configuration).
    args = config.parse_args()
    config = config.get_config(
        args.config,
        overrides=args.override,
        show=True)
    main(config)