mirror of
https://github.com/PaddlePaddle/PaddleClas.git
synced 2025-06-03 21:55:06 +08:00
fixbug bs=1 in build_gallery
This commit is contained in:
parent
c68d411ca9
commit
ad3bd3a342
@ -71,14 +71,26 @@ class GalleryBuilder(object):
|
||||
gallery_features = np.zeros(
|
||||
[len(gallery_images), config['embedding_size']], dtype=np.float32)
|
||||
|
||||
#construct batch imgs and do inference
|
||||
batch_size = config.get("batch_size", 32)
|
||||
batch_img = []
|
||||
for i, image_file in enumerate(tqdm(gallery_images)):
|
||||
img = cv2.imread(image_file)
|
||||
if img is None:
|
||||
logger.error("img empty, please check {}".format(image_file))
|
||||
exit()
|
||||
img = img[:, :, ::-1]
|
||||
rec_feat = self.rec_predictor.predict(img)
|
||||
gallery_features[i, :] = rec_feat
|
||||
batch_img.append(img)
|
||||
|
||||
if (i + 1) % batch_size == 0:
|
||||
rec_feat = self.rec_predictor.predict(batch_img)
|
||||
gallery_features[i - batch_size + 1:i + 1, :] = rec_feat
|
||||
batch_img = []
|
||||
|
||||
if len(batch_img) > 0:
|
||||
rec_feat = self.rec_predictor.predict(batch_img)
|
||||
gallery_features[-len(batch_img):, :] = rec_feat
|
||||
batch_img = []
|
||||
|
||||
# train index
|
||||
self.Searcher = Graph_Index(dist_type=config['dist_type'])
|
||||
|
Loading…
x
Reference in New Issue
Block a user