From be1242775a6e7c025c3ab48388d95bb6a5377d8c Mon Sep 17 00:00:00 2001
From: Maria Lomeli
Date: Tue, 12 Dec 2023 09:51:05 -0800
Subject: [PATCH] Upstream changes to big batch search (#3170)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3170

Add logging info, including the add-to-heap step, and split the queue
wait time into wait_in and wait_out.

Reviewed By: algoriddle

Differential Revision: D52034667

fbshipit-source-id: 8ab864c5c43d534d094c6e81bb810c74e20c9ac2
---
 contrib/big_batch_search.py | 51 +++++++++++++++++++++----------------
 1 file changed, 29 insertions(+), 22 deletions(-)

diff --git a/contrib/big_batch_search.py b/contrib/big_batch_search.py
index 6b0fd36e9..440a538c1 100644
--- a/contrib/big_batch_search.py
+++ b/contrib/big_batch_search.py
@@ -6,6 +6,7 @@
 import time
 import pickle
 import os
+import logging
 from multiprocessing.pool import ThreadPool
 import threading
 import _thread
@@ -41,7 +42,7 @@ class BigBatchSearcher:
         self.use_float16 = use_float16
         keep_max = faiss.is_similarity_metric(index.metric_type)
         self.rh = faiss.ResultHeap(len(xq), k, keep_max=keep_max)
-        self.t_accu = [0] * 5
+        self.t_accu = [0] * 6
         self.t_display = self.t0 = time.time()
 
     def start_t_accu(self):
@@ -74,11 +75,12 @@
             f"[{t:.1f} s] list {l}/{self.index.nlist} "
             f"times prep q {self.t_accu[0]:.3f} prep b {self.t_accu[1]:.3f} "
             f"comp {self.t_accu[2]:.3f} res {self.t_accu[3]:.3f} "
-            f"wait {self.t_accu[4]:.3f} "
+            f"wait in {self.t_accu[4]:.3f} "
+            f"wait out {self.t_accu[5]:.3f} "
             f"eta {datetime.timedelta(seconds=t*self.index.nlist/(l+1)-t)} "
             f"mem {faiss.get_mem_usage_kb()}",
-                end="\r" if self.verbose <= 2 else "\n",
-                flush=True,
+            end="\r" if self.verbose <= 2 else "\n",
+            flush=True,
         )
         self.t_display = time.time()
 
@@ -293,7 +295,7 @@ def big_batch_search(
     )
     mem_tot = mem_queries + mem_assign + mem_res
     if verbose > 0:
-        print(
+        logging.info(
             f"memory: queries {mem_queries} assign {mem_assign} "
             f"result {mem_res} total {mem_tot} = {mem_tot / (1<<30):.3f} GiB"
         )
@@ -312,8 +314,8 @@
     )
 
     bbs.decode_func = comp.decode_func
-    bbs.by_residual = comp.by_residual
 
+    bbs.by_residual = comp.by_residual
     if q_assign is None:
         bbs.coarse_quantization()
     else:
@@ -327,11 +329,11 @@
     if checkpoint is not None:
         assert (start_list, end_list) == (0, index.nlist)
         if os.path.exists(checkpoint):
-            print("recovering checkpoint", checkpoint)
+            logging.info(f"recovering checkpoint: {checkpoint}")
             completed = bbs.read_checkpoint(checkpoint)
-            print(" already completed", len(completed))
+            logging.info(f" already completed: {len(completed)}")
         else:
-            print("no checkpoint: starting from scratch")
+            logging.info("no checkpoint: starting from scratch")
 
     if threaded == 0:
         # simple sequential version
@@ -414,10 +416,10 @@
     def prepare_task(task_id, output_queue, input_queue=None):
         try:
-            # print(f"Prepare start: {task_id}")
+            logging.info(f"Prepare start: {task_id}")
             q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(task_id)
             output_queue.put((task_id, q_subset, xq_l, list_ids, xb_l))
-            # print(f"Prepare end: {task_id}")
+            logging.info(f"Prepare end: {task_id}")
         except:
             traceback.print_exc()
             _thread.interrupt_main()
             raise
@@ -425,18 +427,19 @@
     def compute_task(task_id, output_queue, input_queue):
         try:
-            # print(f"Compute start: {task_id}")
-            t_wait = 0
+            logging.info(f"Compute start: {task_id}")
+            t_wait_out = 0
             while True:
                 t0 = time.time()
+                logging.info(f'Compute input: task {task_id}')
                 input_value = input_queue.get()
-                t_wait += time.time() - t0
+                t_wait_in = time.time() - t0
                 if input_value is None:
                     # signal for other compute tasks
                     input_queue.put(None)
                     break
                 centroid, q_subset, xq_l, list_ids, xb_l = input_value
-                # print(f'Compute work start: task {task_id}, centroid {centroid}')
+                logging.info(f'Compute work: task {task_id}, centroid {centroid}')
                 t0 = time.time()
                 if computation_threads > 1:
                     D, I = comp.block_search(
                         xq_l, xb_l, list_ids, k
@@ -445,13 +448,13 @@
                     )
                 else:
                     D, I = comp.block_search(xq_l, xb_l, list_ids, k)
                 t_compute = time.time() - t0
-                # print(f'Compute work end: task {task_id}, centroid {centroid}')
+                logging.info(f'Compute output: task {task_id}, centroid {centroid}')
                 t0 = time.time()
                 output_queue.put(
-                    (centroid, t_wait, t_compute, q_subset, D, list_ids, I)
+                    (centroid, t_wait_in, t_wait_out, t_compute, q_subset, D, list_ids, I)
                 )
-                t_wait = time.time() - t0
-            # print(f"Compute end: {task_id}")
+                t_wait_out = time.time() - t0
+            logging.info(f"Compute end: {task_id}")
         except:
             traceback.print_exc()
             _thread.interrupt_main()
@@ -480,21 +483,25 @@
 
         t_checkpoint = time.time()
         while True:
+            logging.info("Waiting for result")
             value = compute_to_main_queue.get()
             if not value:
                 break
-            centroid, t_wait, t_compute, q_subset, D, list_ids, I = value
+            centroid, t_wait_in, t_wait_out, t_compute, q_subset, D, list_ids, I = value
             # to test checkpointing
             if centroid == crash_at:
                 1 / 0
             bbs.t_accu[2] += t_compute
-            bbs.t_accu[4] += t_wait
+            bbs.t_accu[4] += t_wait_in
+            bbs.t_accu[5] += t_wait_out
+            logging.info(f"Adding to heap start: centroid {centroid}")
            bbs.add_results_to_heap(q_subset, D, list_ids, I)
+            logging.info(f"Adding to heap end: centroid {centroid}")
             completed.add(centroid)
             bbs.report(centroid)
             if checkpoint is not None:
                 if time.time() - t_checkpoint > checkpoint_freq:
-                    print("writing checkpoint")
+                    logging.info("writing checkpoint")
                     bbs.write_checkpoint(checkpoint, completed)
                     t_checkpoint = time.time()
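
Usage note (not part of the diff): the new messages go through the root
logger via logging.info, which Python silences below WARNING by default,
so a caller has to raise the logging level to see them. A minimal sketch,
assuming the big_batch_search entry point with the index/xq/k arguments
and the verbose/threaded keywords that appear in this diff, and that it
returns (D, I) arrays like Index.search; the index factory string and
dataset sizes are arbitrary:

    import logging

    import numpy as np
    import faiss
    from faiss.contrib.big_batch_search import big_batch_search

    # The patch logs through the root logger; raise its level so the new
    # "Compute work" / "Adding to heap" messages become visible.
    logging.basicConfig(level=logging.INFO)

    d = 64
    rs = np.random.RandomState(123)
    xb = rs.rand(100_000, d).astype("float32")
    xq = rs.rand(10_000, d).astype("float32")

    index = faiss.index_factory(d, "IVF1024,Flat")
    index.train(xb)
    index.add(xb)

    # threaded > 0 exercises the pipelined path that this diff instruments.
    D, I = big_batch_search(index, xq, k=10, verbose=1, threaded=4)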
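
Timing note: wait_in measures how long a compute worker blocks on
input_queue.get() before a prepared bucket arrives, while wait_out
measures how long the previous output_queue.put() blocked handing a
result to the main thread; a large wait_in suggests the prepare threads
are the bottleneck, a large wait_out that the main thread's heap
insertion is. A standalone sketch of that split, with hypothetical
worker/queue names rather than code from this patch:

    import queue
    import time

    def worker(input_queue, output_queue):
        t_wait_out = 0.0
        while True:
            t0 = time.time()
            item = input_queue.get()       # blocks while producers are slow
            t_wait_in = time.time() - t0
            if item is None:               # end-of-work sentinel
                break
            result = item * 2              # stand-in for comp.block_search
            t0 = time.time()
            # ship both wait times with the result, as compute_task does
            output_queue.put((result, t_wait_in, t_wait_out))
            t_wait_out = time.time() - t0  # charged to the next result

    in_q, out_q = queue.Queue(), queue.Queue()
    for x in (1, 2, 3):
        in_q.put(x)
    in_q.put(None)
    worker(in_q, out_q)
    while not out_q.empty():
        print(out_q.get())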