fixbug bs=1 in build_gallery

This commit is contained in:
lubin10 2021-07-07 06:22:21 +00:00
parent c68d411ca9
commit ad3bd3a342
2 changed files with 14 additions and 3 deletions

1
aa.txt
View File

@ -1 +0,0 @@
i have already fix the bug

View File

@ -71,14 +71,26 @@ class GalleryBuilder(object):
gallery_features = np.zeros( gallery_features = np.zeros(
[len(gallery_images), config['embedding_size']], dtype=np.float32) [len(gallery_images), config['embedding_size']], dtype=np.float32)
#construct batch imgs and do inference
batch_size = config.get("batch_size", 32)
batch_img = []
for i, image_file in enumerate(tqdm(gallery_images)): for i, image_file in enumerate(tqdm(gallery_images)):
img = cv2.imread(image_file) img = cv2.imread(image_file)
if img is None: if img is None:
logger.error("img empty, please check {}".format(image_file)) logger.error("img empty, please check {}".format(image_file))
exit() exit()
img = img[:, :, ::-1] img = img[:, :, ::-1]
rec_feat = self.rec_predictor.predict(img) batch_img.append(img)
gallery_features[i, :] = rec_feat
if (i + 1) % batch_size == 0:
rec_feat = self.rec_predictor.predict(batch_img)
gallery_features[i - batch_size + 1:i + 1, :] = rec_feat
batch_img = []
if len(batch_img) > 0:
rec_feat = self.rec_predictor.predict(batch_img)
gallery_features[-len(batch_img):, :] = rec_feat
batch_img = []
# train index # train index
self.Searcher = Graph_Index(dist_type=config['dist_type']) self.Searcher = Graph_Index(dist_type=config['dist_type'])