support for binary index build and search
parent
a368e3eb20
commit
8baf879adb
|
@ -0,0 +1,40 @@
|
|||
Global:
|
||||
#rec_inference_model_dir: "./models/product_ResNet50_vd_aliproduct_v1.0_infer"
|
||||
rec_inference_model_dir: "../inference"
|
||||
batch_size: 32
|
||||
use_gpu: True
|
||||
enable_mkldnn: True
|
||||
cpu_num_threads: 10
|
||||
enable_benchmark: True
|
||||
use_fp16: False
|
||||
ir_optim: True
|
||||
use_tensorrt: False
|
||||
gpu_mem: 8000
|
||||
enable_profile: False
|
||||
|
||||
RecPreProcess:
|
||||
transform_ops:
|
||||
- ResizeImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 0.00392157
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
- ToCHWImage:
|
||||
|
||||
RecPostProcess:
|
||||
main_indicator: Binarize
|
||||
Binarize:
|
||||
method: "round"
|
||||
|
||||
# indexing engine config
|
||||
IndexProcess:
|
||||
index_method: "Flat" # supported: HNSW32, Flat
|
||||
index_dir: "./recognition_demo_data_v1.1/gallery_product/index_binary"
|
||||
image_root: "./recognition_demo_data_v1.1/gallery_product/"
|
||||
data_file: "./recognition_demo_data_v1.1/gallery_product/data_file.txt"
|
||||
index_operation: "new" # supported: "append", "remove", "new"
|
||||
delimiter: "\t"
|
||||
dist_type: "hamming"
|
||||
embedding_size: 512
|
|
@ -0,0 +1,60 @@
|
|||
Global:
|
||||
infer_imgs: "./recognition_demo_data_v1.1/test_product/daoxiangcunjinzhubing_6.jpg"
|
||||
det_inference_model_dir: "./models/ppyolov2_r50vd_dcn_mainbody_v1.0_infer"
|
||||
rec_inference_model_dir: "./models/product_ResNet50_vd_aliproduct_v1.0_infer"
|
||||
rec_nms_thresold: 0.05
|
||||
|
||||
batch_size: 1
|
||||
image_shape: [3, 640, 640]
|
||||
threshold: 0.2
|
||||
max_det_results: 5
|
||||
labe_list:
|
||||
- foreground
|
||||
|
||||
# inference engine config
|
||||
use_gpu: True
|
||||
enable_mkldnn: True
|
||||
cpu_num_threads: 10
|
||||
enable_benchmark: True
|
||||
use_fp16: False
|
||||
ir_optim: True
|
||||
use_tensorrt: False
|
||||
gpu_mem: 8000
|
||||
enable_profile: False
|
||||
|
||||
DetPreProcess:
|
||||
transform_ops:
|
||||
- DetResize:
|
||||
interp: 2
|
||||
keep_ratio: false
|
||||
target_size: [640, 640]
|
||||
- DetNormalizeImage:
|
||||
is_scale: true
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
- DetPermute: {}
|
||||
DetPostProcess: {}
|
||||
|
||||
RecPreProcess:
|
||||
transform_ops:
|
||||
- ResizeImage:
|
||||
size: 224
|
||||
- NormalizeImage:
|
||||
scale: 0.00392157
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
- ToCHWImage:
|
||||
|
||||
RecPostProcess:
|
||||
main_indicator: Binarize
|
||||
Binarize:
|
||||
method: "round"
|
||||
|
||||
# indexing engine config
|
||||
IndexProcess:
|
||||
binary_index: true
|
||||
index_dir: "./recognition_demo_data_v1.1/gallery_product/index_binary"
|
||||
return_k: 5
|
||||
score_thres: 0
|
||||
|
|
@ -28,7 +28,6 @@ from python.predict_rec import RecPredictor
|
|||
from utils import logger
|
||||
from utils import config
|
||||
|
||||
|
||||
def split_datafile(data_file, image_root, delimiter="\t"):
|
||||
'''
|
||||
data_file: image path and info, which can be split by the delimiter
|
||||
|
@ -70,8 +69,8 @@ class GalleryBuilder(object):
|
|||
|
||||
# when removing data from the index, there is no need to extract features
|
||||
if operation_method != "remove":
|
||||
gallery_features = self._extract_features(gallery_images, config)
|
||||
|
||||
gallery_features = self._extract_features(gallery_images, config) #76 * 512
|
||||
|
||||
assert operation_method in [
|
||||
"new", "remove", "append"
|
||||
], "Only append, remove and new operation are supported"
|
||||
|
@ -104,11 +103,22 @@ class GalleryBuilder(object):
|
|||
if index_method == "IVF":
|
||||
index_method = index_method + str(
|
||||
min(int(len(gallery_images) // 8), 65536)) + ",Flat"
|
||||
|
||||
# for binary index, add B at head of index_method
|
||||
if config["dist_type"] == "hamming":
|
||||
index_method = "B" + index_method
|
||||
|
||||
#dist_type
|
||||
dist_type = faiss.METRIC_INNER_PRODUCT if config[
|
||||
"dist_type"] == "IP" else faiss.METRIC_L2
|
||||
index = faiss.index_factory(config["embedding_size"], index_method,
|
||||
dist_type)
|
||||
index = faiss.IndexIDMap2(index)
|
||||
|
||||
#build index
|
||||
if config["dist_type"] == "hamming":
|
||||
index = faiss.index_binary_factory(config["embedding_size"], index_method)
|
||||
else:
|
||||
index = faiss.index_factory(config["embedding_size"], index_method,
|
||||
dist_type)
|
||||
index = faiss.IndexIDMap2(index)
|
||||
ids = {}
|
||||
|
||||
if config["index_method"] == "HNSW32":
|
||||
|
@ -119,12 +129,17 @@ class GalleryBuilder(object):
|
|||
# calculate id for new data
|
||||
start_id = max(ids.keys()) + 1 if ids else 0
|
||||
ids_now = (
|
||||
np.arange(0, len(gallery_images)) + start_id).astype(np.int64)
|
||||
np.arange(0, len(gallery_images)) + start_id).astype(np.int64) #ids: just the number sequence
|
||||
|
||||
# only train when new index file
|
||||
if operation_method == "new":
|
||||
index.train(gallery_features)
|
||||
index.add_with_ids(gallery_features, ids_now)
|
||||
if config["dist_type"] == "hamming":
|
||||
index.add(gallery_features)
|
||||
else:
|
||||
index.train(gallery_features)
|
||||
|
||||
if not config["dist_type"] == "hamming":
|
||||
index.add_with_ids(gallery_features, ids_now)
|
||||
|
||||
for i, d in zip(list(ids_now), gallery_docs):
|
||||
ids[i] = d
|
||||
|
@ -142,15 +157,25 @@ class GalleryBuilder(object):
|
|||
del ids[k]
|
||||
|
||||
# store faiss index file and id_map file
|
||||
faiss.write_index(index,
|
||||
os.path.join(config["index_dir"], "vector.index"))
|
||||
if config["dist_type"] == "hamming":
|
||||
faiss.write_index_binary(index,
|
||||
os.path.join(config["index_dir"], "vector.index"))
|
||||
else:
|
||||
faiss.write_index(index,
|
||||
os.path.join(config["index_dir"], "vector.index"))
|
||||
|
||||
with open(os.path.join(config["index_dir"], "id_map.pkl"), 'wb') as fd:
|
||||
pickle.dump(ids, fd)
|
||||
|
||||
def _extract_features(self, gallery_images, config):
|
||||
|
||||
# extract gallery features
|
||||
gallery_features = np.zeros(
|
||||
[len(gallery_images), config['embedding_size']], dtype=np.float32)
|
||||
if config["dist_type"] == "hamming":
|
||||
gallery_features = np.zeros(
|
||||
[len(gallery_images), config['embedding_size'] // 8], dtype=np.uint8)
|
||||
else:
|
||||
gallery_features = np.zeros(
|
||||
[len(gallery_images), config['embedding_size']], dtype=np.float32)
|
||||
|
||||
#construct batch imgs and do inference
|
||||
batch_size = config.get("batch_size", 32)
|
||||
|
@ -164,7 +189,7 @@ class GalleryBuilder(object):
|
|||
batch_img.append(img)
|
||||
|
||||
if (i + 1) % batch_size == 0:
|
||||
rec_feat = self.rec_predictor.predict(batch_img)
|
||||
rec_feat = self.rec_predictor.predict(batch_img) #32 * 512
|
||||
gallery_features[i - batch_size + 1:i + 1, :] = rec_feat
|
||||
batch_img = []
|
||||
|
||||
|
@ -172,6 +197,7 @@ class GalleryBuilder(object):
|
|||
rec_feat = self.rec_predictor.predict(batch_img)
|
||||
gallery_features[-len(batch_img):, :] = rec_feat
|
||||
batch_img = []
|
||||
|
||||
return gallery_features
|
||||
|
||||
|
||||
|
|
|
@ -62,6 +62,7 @@ class Topk(object):
|
|||
def parse_class_id_map(self, class_id_map_file):
|
||||
if class_id_map_file is None:
|
||||
return None
|
||||
|
||||
if not os.path.exists(class_id_map_file):
|
||||
print(
|
||||
"Warning: If want to use your own label_dict, please input legal path!\nOtherwise label_names will be empty!"
|
||||
|
@ -126,3 +127,42 @@ class SavePreLabel(object):
|
|||
output_dir = self.save_dir(str(id))
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
shutil.copy(image_file, output_dir)
|
||||
|
||||
class Binarize(object):
    """Turn real-valued feature rows into bit-packed ``uint8`` rows.

    Each group of 8 consecutive columns of the input is collapsed into one
    byte, with the first column of the group as the most significant bit.

    Args:
        method: how each value is turned into a bit before packing.
            "round" rounds the value to the nearest of 0 and 1; "sign" maps
            positive values to 1 and everything else to 0. Any other value
            skips this step and packs the input as it is.
    """

    def __init__(self, method="round"):
        self.method = method
        # Column vector of per-bit weights, most significant bit first.
        self.unit = np.array([[128, 64, 32, 16, 8, 4, 2, 1]]).T

    def __call__(self, x, file_names=None):
        """Binarize and pack ``x``.

        Args:
            x: 2-D array whose second dimension is a multiple of 8.
            file_names: accepted but not used here.

        Returns:
            ``uint8`` array of shape ``(x.shape[0], x.shape[1] // 8)``.

        Raises:
            AssertionError: if ``x.shape[1]`` is not a multiple of 8.
        """
        if self.method == "round":
            # The +1 / -1 shift is kept so in-range values round exactly as
            # before. The clip is the fix: without it a value below -0.5
            # wrapped around to 255 in uint8 and a value of 1.5 or more
            # became 2+, and either one spilled into the neighbouring bits
            # of the packed byte.
            x = np.clip(np.round(x + 1) - 1, 0, 1).astype("uint8")
        elif self.method == "sign":
            x = ((np.sign(x) + 1) / 2).astype("uint8")

        embedding_size = x.shape[1]
        assert embedding_size % 8 == 0, "The Binary index only support vectors with sizes multiple of 8"

        # View every row as groups of 8 bits and weight all groups in a
        # single product instead of looping over the output bytes.
        groups = x.reshape(x.shape[0], embedding_size // 8, 8)
        return np.dot(groups, self.unit)[:, :, 0].astype(np.uint8)
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: binarize a random 31 x 64 float32 batch and show
    # the packed result followed by its shape.
    binarizer = Binarize()
    features = np.random.random((31, 64)).astype('float32')
    packed = binarizer(features)
    for item in (packed, packed.shape):
        print(item)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -47,8 +47,14 @@ class SystemPredictor(object):
|
|||
index_dir, "vector.index")), "vector.index not found ..."
|
||||
assert os.path.exists(os.path.join(
|
||||
index_dir, "id_map.pkl")), "id_map.pkl not found ... "
|
||||
self.Searcher = faiss.read_index(
|
||||
os.path.join(index_dir, "vector.index"))
|
||||
|
||||
if config['IndexProcess'].get("binary_index", False):
|
||||
self.Searcher = faiss.read_index_binary(
|
||||
os.path.join(index_dir, "vector.index"))
|
||||
else:
|
||||
self.Searcher = faiss.read_index(
|
||||
os.path.join(index_dir, "vector.index"))
|
||||
|
||||
with open(os.path.join(index_dir, "id_map.pkl"), "rb") as fd:
|
||||
self.id_map = pickle.load(fd)
|
||||
|
||||
|
@ -105,6 +111,7 @@ class SystemPredictor(object):
|
|||
rec_results = self.rec_predictor.predict(crop_img)
|
||||
preds["bbox"] = [xmin, ymin, xmax, ymax]
|
||||
scores, docs = self.Searcher.search(rec_results, self.return_k)
|
||||
|
||||
# just top-1 result will be returned for the final
|
||||
if scores[0][0] >= self.config["IndexProcess"]["score_thres"]:
|
||||
preds["rec_docs"] = self.id_map[docs[0][0]].split()[1]
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
from paddle import nn
|
||||
|
||||
import paddle
|
||||
|
||||
class IdentityHead(nn.Layer):
    """Head that hands the incoming features back without transforming them."""

    def __init__(self):
        super().__init__()

    def forward(self, x, label=None):
        """Return ``x`` under "features" together with a ``None`` "logits" entry.

        ``label`` is accepted but not used.
        """
        outputs = dict(features=x, logits=None)
        return outputs
|
|
@ -378,7 +378,6 @@ class ExportModel(nn.Layer):
|
|||
self.infer_output_key = config.get("infer_output_key", None)
|
||||
if self.infer_output_key == "features" and isinstance(self.base_model,
|
||||
RecModel):
|
||||
self.base_model.head = IdentityHead()
|
||||
if config.get("infer_add_softmax", True):
|
||||
self.softmax = nn.Softmax(axis=-1)
|
||||
else:
|
||||
|
@ -394,10 +393,13 @@ class ExportModel(nn.Layer):
|
|||
x = self.base_model(x)
|
||||
if isinstance(x, list):
|
||||
x = x[0]
|
||||
|
||||
if self.infer_model_name is not None:
|
||||
x = x[self.infer_model_name]
|
||||
|
||||
if self.infer_output_key is not None:
|
||||
x = x[self.infer_output_key]
|
||||
|
||||
if self.softmax is not None:
|
||||
x = self.softmax(x)
|
||||
return x
|
||||
|
|
Loading…
Reference in New Issue