cpp11通用线程池

通用线程池

通用线程池和简单线程池的区别就是,简单线程池对交给线程的任务函数入参和返回值有类型要求,而通用的没有。在本文中,将讲述和简单线程池提交函数的区别。本文需要一定的 C++ 模板泛型基础。

和简单线程池的区别

简单线程池

在简单线程池中,增加任务给线程池的函数如下:

1
2
3
4
5
6
typedef std::function<void()> task_t;
void thread_pool::add_task(const task_t& task) {
std::unique_lock<std::mutex> lock(mutex_);
tasks_.push(task);
wake_cond_.notify_one(); // 只唤醒一个线程
}

可见,task_t 就是我们能向线程池提交的函数类型。它是一个 function 类型,返回值为 void,并且没有入参。那如果我们想提交一个返回值为 int,也有 int 作为入参的函数这样就行不通了。

通用线程池

于是我们将它略微改进下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
// 向线程池加入任务
template<class F, class... Args>
auto add_task(F&& f, Args&&... args) -> std::future<decltype(f(args...))> {
using return_type = decltype(f(args...));
auto task_ptr = std::make_shared<std::packaged_task<return_type()>>( // 因为 packaged_task 无法拷贝构造,所以用 make_shared
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
);

std::future<return_type> res = task_ptr->get_future();
{
std::unique_lock<std::mutex> lock(mutex_);
tasks_.emplace([task_ptr]() {
(*task_ptr)();
});
}
wake_cond_.notify_one();
return res;
}

这个函数有两个模板参数,第一个模板参数 F 代表的是一个函数,第二个模板参数 Args 代表的是这个函数执行时的参数。该函数的返回值是一个 std::future 对象,可以通过 future.get() 来获得执行函数的返回值。具体用法可以见 future用法

因为不知道函数返回值是什么类型,所以需要用到 decltype,在编译时自动确定调用函数的返回值类型。届时从 future 对象中 get 到的数据类型就可以拿到了。在本函数中,先定义了一个指向 packaged_task 的智能指针(packaged_task 本质是个仿函数),之后把这个仿函数包装一层,让它变为一个返回值为 return_type,并且没有入参的函数。之所以用指针,是因为 packaged_task 无法进行拷贝构造,所以下面没法调用。之后在放入执行线程队列前,用 lambda 表达式封装它,让它变为返回值为 void 的函数,把它放到执行线程的队列中。最后该函数返回 packaged_task 对应的 future 对象,方便调用者拿到返回值。

值得注意的是,该函数因为是有模板参数,所以这个实现应该放在头文件中。原因详见:cpp模板类定义放在头文件原因

完整代码

以下代码是对 cpp11简单线程池 基础上做的改进:

main.cc

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#include <iostream>
#include <chrono>
#include <future>
#include "thread_pool.h"

std::mutex g_screen_mutex; // 向终端打印信息的锁

int test(int a, int b) {
std::this_thread::sleep_for(std::chrono::seconds(1));
std::lock_guard<std::mutex> lock(g_screen_mutex);
std::cout << "test() at thread [ " << std::this_thread::get_id() << "] output [" << a + b << "]" << std::endl;
return a + b;
}

int main() {
thread_pool thread_pool;
std::vector<std::future<int>> results;
for(int i = 0; i < 5 ; i++) {
auto fu = thread_pool.add_task(test, i, 2);
results.emplace_back(std::move(fu));
}
getchar(); // 等待,不要让主进程退出
for (int i = 0; i < 5; i++) {
std::cout<< "result of number " << i << " :" << results[i].get() << std::endl;
}
return 0;
}

thread_pool.h:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#ifndef _THREAD_POOL_H_
#define _THREAD_POOL_H_

#include <vector>
#include <queue>
#include <thread>
#include <functional>
#include <mutex>
#include <condition_variable>

class thread_pool {
public:
// 定义为一个函数类型,返回值为 void,没有入参
typedef std::function<void()> task_t;

thread_pool(int init_size = 3);
~thread_pool();
// 停止线程池
void stop();
// 向线程池加入任务
template<class F, class... Args>
auto add_task(F&& f, Args&&... args) -> std::future<decltype(f(args...))> {
using return_type = decltype(f(args...));
auto task_ptr = std::make_shared<std::packaged_task<return_type()>>(
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
);
std::future<return_type> res = task_ptr->get_future();
{
std::unique_lock<std::mutex> lock(mutex_);
tasks_.emplace([task_ptr]() {
(*task_ptr)();
});
}
wake_cond_.notify_one();
return res;
}
private:
thread_pool(const thread_pool&) = delete; // 禁止复制拷贝
const thread_pool& operator=(const thread_pool&) = delete;
// 线程池启动函数
void start();
// 每个线程的循环函数
void thread_loop();
// 从线程池里拿一个线程
task_t take();

int init_threads_size_; // 初始线程数量
std::vector<std::thread*> threads_; // 已经创建的线程列表
std::queue<task_t> tasks_; // 待执行任务列表
std::mutex mutex_; // 操作线程池共有变量之前先上锁
std::condition_variable wake_cond_; // 唤醒线程的条件
bool is_started_; // 线程池是否已经启动
};
#endif

thread_pool.cc:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#include <assert.h>
#include <iostream>
#include <future>
#include "thread_pool.h"

thread_pool::thread_pool(int init_size)
: init_threads_size_(init_size), mutex_(), wake_cond_(), is_started_(false) {
start();
}

thread_pool::~thread_pool() {
if (is_started_) {
stop();
}
}

void thread_pool::start() {
assert(threads_.empty());
is_started_ = true;
threads_.reserve(init_threads_size_);
for (int i = 0; i < init_threads_size_; ++i) {
threads_.push_back(new std::thread(std::bind(&thread_pool::thread_loop, this)));
}
}

void thread_pool::stop() {
std::cout << "thread_pool::stop() stop." << std::endl;
{
std::unique_lock<std::mutex> lock(mutex_);
is_started_ = false;
wake_cond_.notify_all();
std::cout << "thread_pool::stop() notifyAll()." << std::endl;
}

for (auto thread : threads_) {
thread->join();
delete thread;
}
threads_.clear();
}

void thread_pool::thread_loop() {
std::cout << "thread_pool::threadLoop() tid : " << std::this_thread::get_id() << " start." << std::endl;
while (is_started_) {
task_t task = take();
if (task) {
task();
}
}
std::cout << "thread_pool::threadLoop() tid : " << std::this_thread::get_id() << " exit." << std::endl;
}

thread_pool::task_t thread_pool::take() {
std::unique_lock<std::mutex> lock(mutex_);
// 使用 while 循环,防止假唤醒
while (tasks_.empty() && is_started_) {
std::cout << "thread_pool::take() tid : " << std::this_thread::get_id() << " wait." << std::endl;
wake_cond_.wait(lock);
}

std::cout << "thread_pool::take() tid : " << std::this_thread::get_id() << " wakeup." << std::endl;
task_t task;
size_t size = tasks_.size();
if (!tasks_.empty() && is_started_) {
task = tasks_.front();
tasks_.pop();
assert(size - 1 == tasks_.size());
}
if (task != nullptr) {
std::cout << "thread_pool::take() tid : " << std::this_thread::get_id() << " took a task!" << std::endl;
}
return task;
}

运行结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
@└────> # ./test.out 
thread_pool::threadLoop() tid : 139812637472512 start.
thread_pool::take() tid : 139812637472512 wakeup.
thread_pool::take() tid : 139812637472512 took a task!
thread_pool::threadLoop() tid : 139812629079808 start.
thread_pool::take() tid : 139812629079808 wakeup.
thread_pool::take() tid : 139812629079808 took a task!
thread_pool::threadLoop() tid : 139812620687104 start.
thread_pool::take() tid : 139812620687104 wakeup.
thread_pool::take() tid : 139812620687104 took a task!
test() at thread [ 139812637472512] output [2]
thread_pool::take() tid : 139812637472512 wakeup.
thread_pool::take() tid : 139812637472512 took a task!
test() at thread [ 139812629079808] output [3]
thread_pool::take() tid : 139812629079808 wakeup.
thread_pool::take() tid : 139812629079808 took a task!
test() at thread [ 139812620687104] output [4]
thread_pool::take() tid : 139812620687104 wait.
test() at thread [ 139812637472512] output [5]
thread_pool::take() tid : 139812637472512 wait.
test() at thread [ 139812629079808] output [6]
thread_pool::take() tid : 139812629079808 wait.
(键入回车)
result of number 0 :2
result of number 1 :3
result of number 2 :4
result of number 3 :5
result of number 4 :6
thread_pool::stop() stop.
thread_pool::stop() notifyAll().
thread_pool::take() tid : 139812620687104 wakeup.
thread_pool::threadLoop() tid : 139812620687104 exit.
thread_pool::take() tid : 139812637472512 wakeup.
thread_pool::threadLoop() tid : 139812637472512 exit.
thread_pool::take() tid : 139812629079808 wakeup.
thread_pool::threadLoop() tid : 139812629079808 exit.