forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: thread_pool.cpp
140 lines (118 loc) · 3.27 KB
/
thread_pool.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#include <ATen/core/thread_pool.h>
#include <ATen/core/ivalue.h>
namespace c10 {
/// Constructs a pool of `pool_size` worker threads.
///
/// @param pool_size    number of worker threads to spawn; also the initial
///                     value of the available/total counters.
/// @param numa_node_id NUMA node hint stored for init_thread() binding
///                     (semantics live in init_thread, outside this file).
///
/// Workers start immediately and block in main_loop() until tasks arrive.
ThreadPool::ThreadPool(std::size_t pool_size, int numa_node_id)
    : threads_(pool_size),
      running_(true),
      complete_(true),
      available_(pool_size),
      total_(pool_size),
      numa_node_id_(numa_node_id) {
  for (std::size_t i = 0; i < pool_size; ++i) {
    // Lambda instead of std::bind: clearer intent, no bind-expression
    // indirection, and the capture list documents exactly what is used.
    threads_[i] = std::thread([this, i]() { main_loop(i); });
  }
}
/// Shuts the pool down: signals every worker to exit, then joins them all.
ThreadPool::~ThreadPool() {
  // Flip the running flag under the lock, then wake every worker so each
  // one observes the flag and leaves its loop.
  {
    std::unique_lock<std::mutex> guard(mutex_);
    running_ = false;
    condition_.notify_all();
  }
  // Reap each worker. join() may throw (e.g. std::system_error on an
  // unjoinable thread); swallow it so destruction stays best-effort.
  for (auto& worker : threads_) {
    try {
      worker.join();
    } catch (const std::exception&) {
    }
  }
}
/// Total number of worker threads owned by this pool.
size_t ThreadPool::size() const {
  const auto worker_count = threads_.size();
  return worker_count;
}
/// Number of workers currently idle (not running a task).
///
/// NOTE(review): `available_` is read here without holding `mutex_`, while
/// workers increment/decrement it under the lock in main_loop — that is a
/// data race under the C++ memory model unless `available_` is atomic
/// (not visible from this file). Locking here would require `mutex_` to be
/// mutable in the header; confirm the member declarations and fix there.
size_t ThreadPool::numAvailable() const {
  return available_;
}
bool ThreadPool::inThreadPool() const {
for (auto& thread : threads_) {
if (thread.get_id() == std::this_thread::get_id()) {
return true;
}
}
return false;
}
/// Enqueues `func` and wakes one idle worker to execute it.
void ThreadPool::run(const std::function<void()>& func) {
  std::lock_guard<std::mutex> guard(mutex_);
  // Queue the work, mark the pool as having outstanding tasks, and signal
  // a single waiting worker to pick it up.
  tasks_.emplace(func);
  complete_ = false;
  condition_.notify_one();
}
/// Blocks the caller until the queue is drained and every worker is idle
/// (i.e. until main_loop sets `complete_` and notifies `completed_`).
void ThreadPool::waitWorkComplete() {
  std::unique_lock<std::mutex> guard(mutex_);
  // Predicate overload of wait(): re-checks `complete_` on every wakeup,
  // so spurious wakeups are handled for us.
  completed_.wait(guard, [&] { return complete_; });
}
/// Blocks the calling thread until `future` completes.
///
/// Despite the name, the visible body only waits on the future; it does not
/// pull tasks off this pool's queue. Presumably work-stealing was intended
/// or lives elsewhere — TODO confirm against callers.
///
/// @param future future to wait on (taken by value: keeps a reference
///               alive for the duration of the wait).
void ThreadPool::workOnTasksUntilCompleted(
    c10::intrusive_ptr<ivalue::Future> future) {
  // Fast path: nothing to wait for.
  if (future->completed()) {
    return;
  }
  std::condition_variable finished;
  // NOTE(review): the callback captures the stack-local `finished` by
  // reference. If the completing thread is still inside notify_all() when
  // this function observes completed() and returns, the condition variable
  // is destroyed while in use — suspected use-after-free; verify the
  // callback/notification ordering guarantees of Future.
  future->addCallback([&] { finished.notify_all(); });
  // Wait under the future's own mutex so the completed() check and the
  // wait are atomic with respect to the future's completion.
  std::unique_lock<std::mutex> future_lock(future->get_mutex());
  while (!future->completed()) {
    finished.wait(future_lock);
  }
}
void ThreadPool::main_loop(std::size_t index) {
init_thread();
while (running_) {
// Wait on condition variable while the task is empty and
// the pool is still running.
std::unique_lock<std::mutex> lock(mutex_);
while (tasks_.empty() && running_) {
condition_.wait(lock);
}
// If pool is no longer running, break out of loop.
if (!running_) {
break;
}
// Copy task locally and remove from the queue. This is
// done within its own scope so that the task object is
// destructed immediately after running the task. This is
// useful in the event that the function contains
// shared_ptr arguments bound via bind.
{
auto tasks = tasks_.front();
tasks_.pop();
// Decrement count, indicating thread is no longer available.
--available_;
lock.unlock();
// Run the task.
try {
if (tasks.run_with_id) {
tasks.with_id(index);
} else {
tasks.no_id();
}
} catch (const std::exception&) {
}
// Update status of empty, maybe
// Need to recover the lock first
lock.lock();
// Increment count, indicating thread is available.
++available_;
if (tasks_.empty() && available_ == total_) {
complete_ = true;
completed_.notify_one();
}
}
} // while running_
}
/// Process-wide single-worker pool, constructed lazily on first use.
/// C++11 magic-static initialization makes the construction thread-safe.
ThreadPool& global_work_queue() {
  static ThreadPool singleton(/*pool_size=*/1);
  return singleton;
}
} // namespace c10