Upstream changes to big batch search (#3170)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3170

Logging: switch progress output to logging.info, log the add-to-heap steps, and time wait-in and wait-out separately.

Reviewed By: algoriddle

Differential Revision: D52034667

fbshipit-source-id: 8ab864c5c43d534d094c6e81bb810c74e20c9ac2
pull/3175/head
Maria Lomeli 2023-12-12 09:51:05 -08:00 committed by Facebook GitHub Bot
parent 79f558f1d9
commit be1242775a
1 changed file with 29 additions and 22 deletions
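
Note that the messages below go through Python's root logger at INFO level rather than `print`, so a caller has to enable INFO logging to see them. A minimal sketch of the required setup (standard library only; the format string is just an example):

```python
import logging

# big_batch_search now reports progress via logging.info on the root
# logger; without this setup the messages are dropped, since the default
# handler only emits WARNING and above.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
```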


@@ -6,6 +6,7 @@
 import time
 import pickle
 import os
+import logging
 from multiprocessing.pool import ThreadPool
 import threading
 import _thread
@@ -41,7 +42,7 @@ class BigBatchSearcher:
         self.use_float16 = use_float16
         keep_max = faiss.is_similarity_metric(index.metric_type)
         self.rh = faiss.ResultHeap(len(xq), k, keep_max=keep_max)
-        self.t_accu = [0] * 5
+        self.t_accu = [0] * 6
         self.t_display = self.t0 = time.time()

     def start_t_accu(self):
@@ -74,11 +75,12 @@ class BigBatchSearcher:
             f"[{t:.1f} s] list {l}/{self.index.nlist} "
             f"times prep q {self.t_accu[0]:.3f} prep b {self.t_accu[1]:.3f} "
             f"comp {self.t_accu[2]:.3f} res {self.t_accu[3]:.3f} "
-            f"wait {self.t_accu[4]:.3f} "
+            f"wait in {self.t_accu[4]:.3f} "
+            f"wait out {self.t_accu[5]:.3f} "
             f"eta {datetime.timedelta(seconds=t*self.index.nlist/(l+1)-t)} "
             f"mem {faiss.get_mem_usage_kb()}",
             end="\r" if self.verbose <= 2 else "\n",
             flush=True,
         )
         self.t_display = time.time()
@@ -293,7 +295,7 @@ def big_batch_search(
     )
     mem_tot = mem_queries + mem_assign + mem_res
     if verbose > 0:
-        print(
+        logging.info(
             f"memory: queries {mem_queries} assign {mem_assign} "
             f"result {mem_res} total {mem_tot} = {mem_tot / (1<<30):.3f} GiB"
         )
@@ -312,8 +314,8 @@ def big_batch_search(
     )
     bbs.decode_func = comp.decode_func
     bbs.by_residual = comp.by_residual
     if q_assign is None:
         bbs.coarse_quantization()
     else:
@ -327,11 +329,11 @@ def big_batch_search(
if checkpoint is not None:
assert (start_list, end_list) == (0, index.nlist)
if os.path.exists(checkpoint):
print("recovering checkpoint", checkpoint)
logging.info(f"recovering checkpoint: {checkpoint}")
completed = bbs.read_checkpoint(checkpoint)
print(" already completed", len(completed))
logging.info(f" already completed: {len(completed)}")
else:
print("no checkpoint: starting from scratch")
logging.info("no checkpoint: starting from scratch")
if threaded == 0:
# simple sequential version
@@ -414,10 +416,10 @@ def big_batch_search(
     def prepare_task(task_id, output_queue, input_queue=None):
         try:
-            # print(f"Prepare start: {task_id}")
+            logging.info(f"Prepare start: {task_id}")
             q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(task_id)
             output_queue.put((task_id, q_subset, xq_l, list_ids, xb_l))
-            # print(f"Prepare end: {task_id}")
+            logging.info(f"Prepare end: {task_id}")
         except:
             traceback.print_exc()
             _thread.interrupt_main()
@@ -425,18 +427,19 @@ def big_batch_search(
     def compute_task(task_id, output_queue, input_queue):
         try:
-            # print(f"Compute start: {task_id}")
-            t_wait = 0
+            logging.info(f"Compute start: {task_id}")
+            t_wait_out = 0
             while True:
                 t0 = time.time()
+                logging.info(f'Compute input: task {task_id}')
                 input_value = input_queue.get()
-                t_wait += time.time() - t0
+                t_wait_in = time.time() - t0
                 if input_value is None:
                     # signal for other compute tasks
                     input_queue.put(None)
                     break
                 centroid, q_subset, xq_l, list_ids, xb_l = input_value
-                # print(f'Compute work start: task {task_id}, centroid {centroid}')
+                logging.info(f'Compute work: task {task_id}, centroid {centroid}')
                 t0 = time.time()
                 if computation_threads > 1:
                     D, I = comp.block_search(
@@ -445,13 +448,13 @@ def big_batch_search(
                 else:
                     D, I = comp.block_search(xq_l, xb_l, list_ids, k)
                 t_compute = time.time() - t0
-                # print(f'Compute work end: task {task_id}, centroid {centroid}')
+                logging.info(f'Compute output: task {task_id}, centroid {centroid}')
                 t0 = time.time()
                 output_queue.put(
-                    (centroid, t_wait, t_compute, q_subset, D, list_ids, I)
+                    (centroid, t_wait_in, t_wait_out, t_compute, q_subset, D, list_ids, I)
                 )
-                t_wait = time.time() - t0
-            # print(f"Compute end: {task_id}")
+                t_wait_out = time.time() - t0
+            logging.info(f"Compute end: {task_id}")
         except:
             traceback.print_exc()
             _thread.interrupt_main()
@@ -480,21 +483,25 @@ def big_batch_search(
         t_checkpoint = time.time()
         while True:
+            logging.info("Waiting for result")
             value = compute_to_main_queue.get()
             if not value:
                 break
-            centroid, t_wait, t_compute, q_subset, D, list_ids, I = value
+            centroid, t_wait_in, t_wait_out, t_compute, q_subset, D, list_ids, I = value
             # to test checkpointing
             if centroid == crash_at:
                 1 / 0
             bbs.t_accu[2] += t_compute
-            bbs.t_accu[4] += t_wait
+            bbs.t_accu[4] += t_wait_in
+            bbs.t_accu[5] += t_wait_out
+            logging.info(f"Adding to heap start: centroid {centroid}")
             bbs.add_results_to_heap(q_subset, D, list_ids, I)
+            logging.info(f"Adding to heap end: centroid {centroid}")
             completed.add(centroid)
             bbs.report(centroid)
             if checkpoint is not None:
                 if time.time() - t_checkpoint > checkpoint_freq:
-                    print("writing checkpoint")
+                    logging.info("writing checkpoint")
                     bbs.write_checkpoint(checkpoint, completed)
                     t_checkpoint = time.time()
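
For reference, the wait-time split introduced here follows a common producer/consumer pattern: a worker separately times how long it blocks on `input_queue.get()` (wait in) versus `output_queue.put()` (wait out), so queue starvation and backpressure can be told apart in the report. A standalone sketch of that pattern (the queue names and the squaring workload are placeholders, not the faiss code):

```python
import queue
import threading
import time

def compute_worker(input_queue, output_queue):
    # Split blocked time the same way compute_task does: t_wait_in is time
    # spent waiting for work, t_wait_out is time spent waiting to hand
    # results to a slow consumer (a bounded output queue can block put()).
    t_wait_in = t_wait_out = 0.0
    while True:
        t0 = time.time()
        item = input_queue.get()          # blocks until work arrives
        t_wait_in += time.time() - t0
        if item is None:                  # sentinel: no more work
            break
        result = item * item              # placeholder computation
        t0 = time.time()
        output_queue.put(result)          # blocks when the queue is full
        t_wait_out += time.time() - t0
    output_queue.put(None)                # tell the consumer we are done
    print(f"wait in {t_wait_in:.3f} s, wait out {t_wait_out:.3f} s")

in_q, out_q = queue.Queue(), queue.Queue(maxsize=2)
t = threading.Thread(target=compute_worker, args=(in_q, out_q))
t.start()
for i in range(10):
    in_q.put(i)
in_q.put(None)
while out_q.get() is not None:            # drain results
    time.sleep(0.01)                      # slow consumer -> wait-out grows
t.join()
```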