Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
WorkerThread.cpp
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 // Copyright 2004-present Facebook. All Rights Reserved.
11 
#include "WorkerThread.h"
#include "../../FaissAssert.h"

#include <exception>
14 
15 namespace faiss { namespace gpu {
16 
17 WorkerThread::WorkerThread() :
18  wantStop_(false) {
19  startThread();
20 
21  // Make sure that the thread has started before continuing
22  add([](){}).get();
23 }
24 
25 WorkerThread::~WorkerThread() {
26  stop();
27  waitForThreadExit();
28 }
29 
30 void
31 WorkerThread::startThread() {
32  thread_ = std::thread([this](){ threadMain(); });
33 }
34 
35 void
36 WorkerThread::stop() {
37  std::lock_guard<std::mutex> guard(mutex_);
38 
39  wantStop_ = true;
40  monitor_.notify_one();
41 }
42 
43 std::future<bool>
44 WorkerThread::add(std::function<void()> f) {
45  std::lock_guard<std::mutex> guard(mutex_);
46 
47  if (wantStop_) {
48  // The timer thread has been stopped, or we want to stop; we can't
49  // schedule anything else
50  std::promise<bool> p;
51  auto fut = p.get_future();
52 
53  // did not execute
54  p.set_value(false);
55  return fut;
56  }
57 
58  auto pr = std::promise<bool>();
59  auto fut = pr.get_future();
60 
61  queue_.emplace_back(std::make_pair(std::move(f), std::move(pr)));
62 
63  // Wake up our thread
64  monitor_.notify_one();
65  return fut;
66 }
67 
68 void
69 WorkerThread::threadMain() {
70  threadLoop();
71 
72  // Call all pending tasks
73  FAISS_ASSERT(wantStop_);
74 
75  for (auto& f : queue_) {
76  f.first();
77  f.second.set_value(true);
78  }
79 }
80 
81 void
82 WorkerThread::threadLoop() {
83  while (true) {
84  std::pair<std::function<void()>, std::promise<bool>> data;
85 
86  {
87  std::unique_lock<std::mutex> lock(mutex_);
88 
89  while (!wantStop_ && queue_.empty()) {
90  monitor_.wait(lock);
91  }
92 
93  if (wantStop_) {
94  return;
95  }
96 
97  data = std::move(queue_.front());
98  queue_.pop_front();
99  }
100 
101  data.first();
102  data.second.set_value(true);
103  }
104 }
105 
106 void
107 WorkerThread::waitForThreadExit() {
108  try {
109  thread_.join();
110  } catch (...) {
111  }
112 }
113 
114 } } // namespace
virtual void add(idx_t n, const float *x) override