Upstream changes to big batch search (#3170)

Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3170
Adds logging info, plus separate timings for adding to the heap and for the wait-in and wait-out phases.

Reviewed By: algoriddle
Differential Revision: D52034667
fbshipit-source-id: 8ab864c5c43d534d094c6e81bb810c74e20c9ac2
Branch: pull/3175/head
parent 79f558f1d9
commit be1242775a
@@ -6,6 +6,7 @@
 import time
 import pickle
 import os
+import logging
 from multiprocessing.pool import ThreadPool
 import threading
 import _thread
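Note: since this commit replaces print calls with logging.info throughout, the progress messages are invisible under Python's default WARNING root level. A minimal sketch of how a caller might surface them (the format string is illustrative, not part of this patch):

import logging

# INFO level is required to see big_batch_search's progress messages;
# threadName helps tell the prepare/compute workers apart.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(threadName)s] %(message)s",
)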
@@ -41,7 +42,7 @@ class BigBatchSearcher:
         self.use_float16 = use_float16
         keep_max = faiss.is_similarity_metric(index.metric_type)
         self.rh = faiss.ResultHeap(len(xq), k, keep_max=keep_max)
-        self.t_accu = [0] * 5
+        self.t_accu = [0] * 6
         self.t_display = self.t0 = time.time()
 
     def start_t_accu(self):
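Context for the resized accumulator: self.rh merges per-bucket partial results into a global top-k, with keep_max=True for similarity (inner-product style) metrics. A self-contained sketch of that merge pattern, assuming faiss.ResultHeap's add_result/finalize interface:

import numpy as np
import faiss

nq, k = 100, 10
rh = faiss.ResultHeap(nq, k)  # keep_max=True for similarity metrics

# merge two batches of partial results (random data for illustration)
for _ in range(2):
    D = np.random.rand(nq, k).astype("float32")
    I = np.random.randint(0, 1000, size=(nq, k)).astype("int64")
    rh.add_result(D, I)

rh.finalize()                  # sort each per-query heap
print(rh.D.shape, rh.I.shape)  # (100, 10) (100, 10)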
@@ -74,11 +75,12 @@ class BigBatchSearcher:
                 f"[{t:.1f} s] list {l}/{self.index.nlist} "
                 f"times prep q {self.t_accu[0]:.3f} prep b {self.t_accu[1]:.3f} "
                 f"comp {self.t_accu[2]:.3f} res {self.t_accu[3]:.3f} "
-                f"wait {self.t_accu[4]:.3f} "
+                f"wait in {self.t_accu[4]:.3f} "
+                f"wait out {self.t_accu[5]:.3f} "
                 f"eta {datetime.timedelta(seconds=t*self.index.nlist/(l+1)-t)} "
                 f"mem {faiss.get_mem_usage_kb()}",
                 end="\r" if self.verbose <= 2 else "\n",
                 flush=True,
             )
             self.t_display = time.time()
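The eta field extrapolates linearly from the lists processed so far: projected total time t * nlist / (l + 1) minus elapsed time t. A worked example with invented numbers:

import datetime

t = 120.0     # seconds elapsed since t0
l = 99        # zero-based index of the list just reported
nlist = 1000  # total number of inverted lists

eta = datetime.timedelta(seconds=t * nlist / (l + 1) - t)
print(eta)  # 0:18:00 -- 10% done in 2 minutes, so 18 minutes to go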
@@ -293,7 +295,7 @@ def big_batch_search(
     )
     mem_tot = mem_queries + mem_assign + mem_res
     if verbose > 0:
-        print(
+        logging.info(
             f"memory: queries {mem_queries} assign {mem_assign} "
             f"result {mem_res} total {mem_tot} = {mem_tot / (1<<30):.3f} GiB"
         )
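The GiB figure is just mem_tot / 2**30. A hedged sketch of what the three terms typically amount to (the per-term formulas below are illustrative guesses, not the ones in this file):

nq, d, nprobe, k = 10_000_000, 128, 16, 10
mem_queries = 4 * nq * d      # float32 query matrix
mem_assign = 8 * nq * nprobe  # int64 coarse assignment
mem_res = (4 + 8) * nq * k    # float32 D plus int64 I result buffers
mem_tot = mem_queries + mem_assign + mem_res
print(f"{mem_tot / (1 << 30):.3f} GiB")  # 7.078 GiB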
@@ -312,8 +314,8 @@ def big_batch_search(
     )
 
     bbs.decode_func = comp.decode_func
-    bbs.by_residual = comp.by_residual
 
+    bbs.by_residual = comp.by_residual
     if q_assign is None:
         bbs.coarse_quantization()
     else:
@@ -327,11 +329,11 @@ def big_batch_search(
     if checkpoint is not None:
         assert (start_list, end_list) == (0, index.nlist)
         if os.path.exists(checkpoint):
-            print("recovering checkpoint", checkpoint)
+            logging.info(f"recovering checkpoint: {checkpoint}")
             completed = bbs.read_checkpoint(checkpoint)
-            print(" already completed", len(completed))
+            logging.info(f" already completed: {len(completed)}")
         else:
-            print("no checkpoint: starting from scratch")
+            logging.info("no checkpoint: starting from scratch")
 
     if threaded == 0:
         # simple sequential version
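The recover-or-start-fresh branch above pairs with the periodic write_checkpoint calls in the main loop (last hunk). One safe way to implement such a pair with pickle, sketched here with assumed names and layout rather than FAISS's actual implementation:

import os
import pickle

def write_checkpoint_sketch(fname, heap_state, completed):
    # write to a temp file, then rename atomically, so a crash
    # mid-write never leaves a truncated checkpoint behind
    tmp = fname + ".tmp"
    with open(tmp, "wb") as f:
        pickle.dump({"heap": heap_state, "completed": completed}, f)
    os.replace(tmp, fname)

def read_checkpoint_sketch(fname):
    with open(fname, "rb") as f:
        state = pickle.load(f)
    return state["heap"], state["completed"]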
@@ -414,10 +416,10 @@ def big_batch_search(
 
     def prepare_task(task_id, output_queue, input_queue=None):
         try:
-            # print(f"Prepare start: {task_id}")
+            logging.info(f"Prepare start: {task_id}")
             q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(task_id)
             output_queue.put((task_id, q_subset, xq_l, list_ids, xb_l))
-            # print(f"Prepare end: {task_id}")
+            logging.info(f"Prepare end: {task_id}")
         except:
             traceback.print_exc()
             _thread.interrupt_main()
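prepare_task and compute_task are stages of a queue-based pipeline: prefetch threads decode buckets and feed compute threads, which feed the main thread, with a None sentinel shutting each stage down. A stripped-down, runnable sketch of the topology (payloads and queue names simplified from the patch):

import queue
import threading

prepare_to_compute = queue.Queue(maxsize=4)  # bounded: applies backpressure
compute_to_main = queue.Queue(maxsize=4)

def prepare(bucket_ids):
    for b in bucket_ids:
        prepare_to_compute.put(("bucket data", b))
    prepare_to_compute.put(None)          # end-of-stream sentinel

def compute():
    while True:
        item = prepare_to_compute.get()
        if item is None:
            prepare_to_compute.put(None)  # re-post so sibling workers stop too
            break
        compute_to_main.put(item)         # stand-in for block_search
    compute_to_main.put(None)

threading.Thread(target=prepare, args=(range(8),)).start()
threading.Thread(target=compute).start()

while True:                               # main thread collects results
    value = compute_to_main.get()
    if value is None:
        break
    print("main received", value)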
@@ -425,18 +427,19 @@ def big_batch_search(
 
     def compute_task(task_id, output_queue, input_queue):
         try:
-            # print(f"Compute start: {task_id}")
-            t_wait = 0
+            logging.info(f"Compute start: {task_id}")
+            t_wait_out = 0
             while True:
                 t0 = time.time()
+                logging.info(f'Compute input: task {task_id}')
                 input_value = input_queue.get()
-                t_wait += time.time() - t0
+                t_wait_in = time.time() - t0
                 if input_value is None:
                     # signal for other compute tasks
                     input_queue.put(None)
                     break
                 centroid, q_subset, xq_l, list_ids, xb_l = input_value
-                # print(f'Compute work start: task {task_id}, centroid {centroid}')
+                logging.info(f'Compute work: task {task_id}, centroid {centroid}')
                 t0 = time.time()
                 if computation_threads > 1:
                     D, I = comp.block_search(
@@ -445,13 +448,13 @@ def big_batch_search(
                 else:
                     D, I = comp.block_search(xq_l, xb_l, list_ids, k)
                 t_compute = time.time() - t0
-                # print(f'Compute work end: task {task_id}, centroid {centroid}')
+                logging.info(f'Compute output: task {task_id}, centroid {centroid}')
                 t0 = time.time()
                 output_queue.put(
-                    (centroid, t_wait, t_compute, q_subset, D, list_ids, I)
+                    (centroid, t_wait_in, t_wait_out, t_compute, q_subset, D, list_ids, I)
                 )
-                t_wait = time.time() - t0
-            # print(f"Compute end: {task_id}")
+                t_wait_out = time.time() - t0
+            logging.info(f"Compute end: {task_id}")
         except:
             traceback.print_exc()
             _thread.interrupt_main()
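The point of the renaming: t_wait_in measures time blocked on input_queue.get() (the worker is starved upstream), while t_wait_out measures time blocked on output_queue.put() (the worker is backed up downstream), so the pair shows which side of the compute stage is the bottleneck. The measurement pattern, reduced to its skeleton (the worker body is a stand-in):

import time

def instrumented_loop(input_queue, output_queue):
    t_wait_out = 0.0
    while True:
        t0 = time.time()
        item = input_queue.get()
        t_wait_in = time.time() - t0   # blocked waiting for work
        if item is None:
            break
        result = item * 2              # stand-in for the real computation
        t0 = time.time()
        output_queue.put((result, t_wait_in, t_wait_out))
        t_wait_out = time.time() - t0  # blocked on a full output queue

Note one subtlety inherited from the patch: the wait-out time attached to a result is the time spent publishing the previous result, so the first result always carries a t_wait_out of 0.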
@@ -480,21 +483,25 @@ def big_batch_search(
 
             t_checkpoint = time.time()
             while True:
+                logging.info("Waiting for result")
                 value = compute_to_main_queue.get()
                 if not value:
                     break
-                centroid, t_wait, t_compute, q_subset, D, list_ids, I = value
+                centroid, t_wait_in, t_wait_out, t_compute, q_subset, D, list_ids, I = value
                 # to test checkpointing
                 if centroid == crash_at:
                     1 / 0
                 bbs.t_accu[2] += t_compute
-                bbs.t_accu[4] += t_wait
+                bbs.t_accu[4] += t_wait_in
+                bbs.t_accu[5] += t_wait_out
+                logging.info(f"Adding to heap start: centroid {centroid}")
                 bbs.add_results_to_heap(q_subset, D, list_ids, I)
+                logging.info(f"Adding to heap end: centroid {centroid}")
                 completed.add(centroid)
                 bbs.report(centroid)
                 if checkpoint is not None:
                     if time.time() - t_checkpoint > checkpoint_freq:
-                        print("writing checkpoint")
+                        logging.info("writing checkpoint")
                         bbs.write_checkpoint(checkpoint, completed)
                         t_checkpoint = time.time()
 
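With the accumulator slots now reading prep q / prep b / comp / res / wait in / wait out, an end-to-end call looks roughly like the sketch below. The keyword arguments beyond index, xq, k and the (D, I) return value are assumptions read off this file, not a documented API:

import logging
import numpy as np
import faiss
from faiss.contrib.big_batch_search import big_batch_search

logging.basicConfig(level=logging.INFO)

d, nb, nq, k = 64, 100_000, 10_000, 10
rng = np.random.default_rng(123)
xb = rng.random((nb, d), dtype="float32")
xq = rng.random((nq, d), dtype="float32")

index = faiss.index_factory(d, "IVF256,Flat")
index.train(xb)
index.add(xb)

# threaded > 0 enables the prepare/compute/main pipeline timed above
D, I = big_batch_search(index, xq, k, threaded=2, verbose=1)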