fixbug bs=1 in build_gallery

This commit is contained in:
lubin10 2021-07-07 06:22:21 +00:00
parent c68d411ca9
commit ad3bd3a342
2 changed files with 14 additions and 3 deletions

1
aa.txt
View File

@ -1 +0,0 @@
i have already fix the bug

View File

@ -71,14 +71,26 @@ class GalleryBuilder(object):
gallery_features = np.zeros(
[len(gallery_images), config['embedding_size']], dtype=np.float32)
#construct batch imgs and do inference
batch_size = config.get("batch_size", 32)
batch_img = []
for i, image_file in enumerate(tqdm(gallery_images)):
img = cv2.imread(image_file)
if img is None:
logger.error("img empty, please check {}".format(image_file))
exit()
img = img[:, :, ::-1]
rec_feat = self.rec_predictor.predict(img)
gallery_features[i, :] = rec_feat
batch_img.append(img)
if (i + 1) % batch_size == 0:
rec_feat = self.rec_predictor.predict(batch_img)
gallery_features[i - batch_size + 1:i + 1, :] = rec_feat
batch_img = []
if len(batch_img) > 0:
rec_feat = self.rec_predictor.predict(batch_img)
gallery_features[-len(batch_img):, :] = rec_feat
batch_img = []
# train index
self.Searcher = Graph_Index(dist_type=config['dist_type'])