suanPan
Loading...
Searching...
No Matches
thread_pool.hpp
Go to the documentation of this file.
1#ifndef THREAD_POOL_HPP
2#define THREAD_POOL_HPP
3
4#include <atomic>
5#include <condition_variable>
6#include <exception>
7#include <functional>
8#include <future>
9#include <memory>
10#include <mutex>
11#include <queue>
12#include <thread>
13#include <type_traits>
14#include <utility>
15
/// @brief Fixed-size pool of worker threads that execute queued tasks in FIFO order.
///
/// Work is queued with push_task() (fire-and-forget), submit() (result delivered through a
/// std::future), or push_loop() (an index range split into blocks). The destructor waits for
/// every queued and running task to finish, then stops and joins the workers.
///
/// Copying and moving are disabled because the worker threads hold `this`.
class [[nodiscard]] thread_pool {
    using concurrency_t = std::invoke_result_t<decltype(std::thread::hardware_concurrency)>;

    /// Set while workers should keep serving the queue; cleared by destroy_threads() to request shutdown.
    std::atomic<bool> running = false;

    /// Signalled when a task is queued or when shutdown is requested.
    std::condition_variable task_available_cv = {};

    /// Signalled (while holding tasks_mutex) when tasks_total drops to zero.
    std::condition_variable task_done_cv = {};

    /// Tasks queued but not yet picked up by a worker.
    std::queue<std::function<void()>> tasks = {};

    /// Number of tasks queued or currently running; only modified while holding tasks_mutex.
    std::atomic<size_t> tasks_total = 0;

    /// Guards `tasks` and every modification of `tasks_total`.
    mutable std::mutex tasks_mutex = {};

    /// Number of worker threads. Declared before `threads` so the constructor can use it to size the array.
    concurrency_t thread_count = 0;

    std::unique_ptr<std::thread[]> threads = nullptr;

    /// Starts `thread_count` workers, each running worker().
    void create_threads() {
        running = true;
        for(concurrency_t i = 0; i < thread_count; ++i) threads[i] = std::thread(&thread_pool::worker, this);
    }

    /// Requests shutdown, wakes every worker and joins them.
    void destroy_threads() {
        running = false;
        {
            // Notify under the lock so a worker cannot miss the wake-up between
            // evaluating its wait predicate and blocking.
            const std::scoped_lock tasks_lock(tasks_mutex);
            task_available_cv.notify_all();
        }
        for(concurrency_t i = 0; i < thread_count; ++i) threads[i].join();
    }

    /// Returns the requested count if positive, otherwise the hardware concurrency, falling back to 1.
    [[nodiscard]] static concurrency_t determine_thread_count(const concurrency_t thread_count_) {
        if(thread_count_ > 0) return thread_count_;
        const concurrency_t hardware = std::thread::hardware_concurrency();
        return hardware > 0 ? hardware : 1;
    }

    /// Worker loop: pop one task at a time and run it outside the lock until shutdown is requested.
    void worker() {
        while(running) {
            std::unique_lock tasks_lock(tasks_mutex);
            task_available_cv.wait(tasks_lock, [this] { return !tasks.empty() || !running; });
            // Shutdown leaves without draining the queue; the destructor calls
            // wait_for_tasks() before destroy_threads(), so the queue is empty by then.
            if(!running) break;
            std::function<void()> task = std::move(tasks.front());
            tasks.pop();
            tasks_lock.unlock();
            // NOTE: an exception escaping a push_task() task leaves this thread function and
            // terminates the program; submit() catches exceptions and forwards them to the future.
            task();
            tasks_lock.lock();
            // Only the transition to zero can satisfy wait_for_tasks(), so wake every waiter then.
            if(--tasks_total == 0) task_done_cv.notify_all();
        }
    }

public:
    /// @param thread_count_ number of workers; 0 selects std::thread::hardware_concurrency() (at least 1).
    explicit thread_pool(const concurrency_t thread_count_ = 0)
        : thread_count(determine_thread_count(thread_count_))
        , threads(std::make_unique<std::thread[]>(thread_count)) { create_threads(); } // thread_count is initialised first (declaration order)

    thread_pool(const thread_pool&) = delete;
    thread_pool(thread_pool&&) = delete;
    thread_pool& operator=(const thread_pool&) = delete;
    thread_pool& operator=(thread_pool&&) = delete;

    /// Waits for all queued and running tasks, then stops and joins the workers.
    ~thread_pool() {
        wait_for_tasks();
        destroy_threads();
    }

    [[nodiscard]] concurrency_t get_thread_count() const { return thread_count; }

    /// @brief Splits [first_index_, index_after_last_) into blocks and queues `loop(start, end)` for each block.
    /// @param first_index_ first index of the range (bounds are swapped if given in reverse order)
    /// @param index_after_last_ one past the last index of the range
    /// @param loop callable invoked as loop(block_start, block_end) on a half-open block; copied once per block
    /// @param num_blocks number of blocks; 0 means one per thread. It is reduced to the range size when the
    ///        range is smaller, and the last block absorbs the remainder.
    template<typename F, typename T1, typename T2, typename T = std::common_type_t<T1, T2>> void push_loop(T1 first_index_, T2 index_after_last_, F&& loop, size_t num_blocks = 0) {
        T first_index = static_cast<T>(first_index_);
        T index_after_last = static_cast<T>(index_after_last_);
        if(num_blocks == 0) num_blocks = thread_count;
        if(index_after_last < first_index) std::swap(index_after_last, first_index);
        const size_t total_size = static_cast<size_t>(index_after_last - first_index);
        if(total_size == 0) return;
        size_t block_size = total_size / num_blocks;
        if(block_size == 0) {
            // Fewer indices than blocks: one index per block.
            block_size = 1;
            num_blocks = total_size;
        }
        for(size_t i = 0; i < num_blocks; ++i) {
            const T block_start = static_cast<T>(i * block_size) + first_index;
            const T block_end = i == num_blocks - 1 ? index_after_last : static_cast<T>((i + 1) * block_size) + first_index;
            // Pass `loop` as an lvalue so every block receives its own copy; forwarding an
            // rvalue here would move from it on the first iteration and leave later blocks
            // with a moved-from callable.
            push_task(loop, block_start, block_end);
        }
    }

    /// Same as push_loop(0, index_after_last, loop, num_blocks).
    template<typename F, typename T> void push_loop(const T index_after_last, F&& loop, const size_t num_blocks = 0) { push_loop(0, index_after_last, std::forward<F>(loop), num_blocks); }

    /// @brief Queues `task(args...)` without a way to retrieve its result.
    /// Arguments are bound with std::bind (decay-copied); exceptions thrown by the task are not caught.
    template<typename F, typename... A> void push_task(F&& task, A&&... args) {
        // Build the type-erased callable outside the critical section, then move it into the queue.
        std::function<void()> task_function = std::bind(std::forward<F>(task), std::forward<A>(args)...);
        {
            const std::scoped_lock tasks_lock(tasks_mutex);
            tasks.push(std::move(task_function));
            ++tasks_total;
        }
        task_available_cv.notify_one();
    }

    /// @brief Queues `task(args...)` and returns a future for its result.
    /// @return future that yields the task's return value or rethrows the exception it threw
    template<typename F, typename... A, typename R = std::invoke_result_t<std::decay_t<F>, std::decay_t<A>...>> [[nodiscard]] std::future<R> submit(F&& task, A&&... args) {
        std::function<R()> task_function = std::bind(std::forward<F>(task), std::forward<A>(args)...);
        std::shared_ptr<std::promise<R>> task_promise = std::make_shared<std::promise<R>>();
        push_task([task_function = std::move(task_function), task_promise] {
            try {
                if constexpr(std::is_void_v<R>) {
                    std::invoke(task_function);
                    task_promise->set_value();
                }
                else { task_promise->set_value(std::invoke(task_function)); }
            }
            catch(...) {
                // set_exception can itself throw (e.g. promise already satisfied); never let it escape the worker.
                try { task_promise->set_exception(std::current_exception()); }
                catch(...) {}
            }
        });
        return task_promise->get_future();
    }

    /// @brief Blocks until every queued and running task has finished.
    /// Safe to call from several threads at once. Must not be called from inside a pool task:
    /// the calling task itself is counted and would never finish.
    void wait_for_tasks() {
        std::unique_lock tasks_lock(tasks_mutex);
        task_done_cv.wait(tasks_lock, [this] { return tasks_total == 0; });
    }
};
135
136#endif
class thread_pool
Definition thread_pool.hpp:16
thread_pool(const concurrency_t thread_count_=0)
Definition thread_pool.hpp:71
void push_task(F &&task, A &&... args)
Definition thread_pool.hpp:98
void wait_for_tasks()
Definition thread_pool.hpp:127
void push_loop(T1 first_index_, T2 index_after_last_, F &&loop, size_t num_blocks=0)
Definition thread_pool.hpp:82
~thread_pool()
Definition thread_pool.hpp:75
std::future< R > submit(F &&task, A &&... args)
Definition thread_pool.hpp:108
concurrency_t get_thread_count() const
Definition thread_pool.hpp:80
void push_loop(const T index_after_last, F &&loop, const size_t num_blocks=0)
Definition thread_pool.hpp:96