C++写一个线程池
文章目录
- C++写一个线程池
- 设计思路
- 测试数据的实现
- 任务类的实现
- 线程池类的实现
- 线程池构造函数
- 线程池入口函数
- 队列中取任务
- 添加任务函数
- 线程池终止函数
- 源码
之前用C语言写了一个线程池,详情请见:
C语言写一个线程池
这次换成C++了!由于C++支持泛型编程,所以代码的灵活性提高了不知道多少倍!!!!!
设计思路
线程池的原理就不解释了,这次主要来讲一下我使用C++进行面向过程、面向对象、泛型设计时的思想。
线程池的工作原理是存在一个具有多个线程的空间,我们对这些线程进行一个统一的管理(利用C++提供的线程类)。然后具有一个存放任务的队列,这些线程依次从中取出任务然后执行。
从上面的过程中发现可以将线程池作为一个对象来进行设计。这个对象中的成员至少有:
- 存放n个线程对象的空间,可以使用
vector<std::thread>
进行管理。 - 标记每一个线程的工作状态的容器,这里可以使用
unordered_map<<std::thread::id, bool>
来进行管理。 - 存放任务的队列
- 等等…(在设计的过程中会发现)
我希望设计一个很万能的线程池,即可以接受任何任务,我只需要将对应的函数和参数传入进队列中,就可以让线程自动执行。
因此,设计一个任务类是必不可少的。因为,我们从函数外部传进去的函数和参数不一定相同,而且不同功能的任务之间没有一个合适的管理方式,因此我们需要设计一个任务类来兼容不同参数,并且将参数和工作函数绑定到一块的情况。
测试数据的实现
我个人比较喜欢在设计代码之前假设他已经设计好,然后先写出他的测试方法和数据,之后一点点来实现,这次我选择的测试方法是求1~50000000中的素数个数,不使用素数筛:
先实现判断素数的功能函数:
// 判断素数n是不是素数
bool is_prime(int n) {
for (int i = 2; i * i < n; ++i) {
if (n % i == 0) return false;
}
return true;
}
// 求出 start 到 end 返回的素数个数
int prime_count(int start, int end) {
int ans = 0;
for (int i = start; i <= end; ++i) {
ans += is_prime(i);
}
return ans;
}
// 需要传入到线程池内的函数,求出 l 到 r 的素数个数然后保存在 ret 中
void worker(int l, int r, int &ret) {
cout << this_thread::get_id() << ": begin" << endl;
ret = prime_count(l, r);
cout << this_thread::get_id() << ": end" << endl;
return ;
}
下面是主函数:
int main() {
#define MAX_N 5000000 // 这里假设需要处理10次,因此每次处理五百万的数据
thread_pool tp(5); // 假设传入的参数是线程池中的线程个数
int ret[10]; // 存放10次结果
for (int i = 0. j = 1; i < 10; ++i, j += MAX_N) {
tp.add_task(worker, j, j + MAX_N - 1, ref(ret[i])); // ref是用来传入引用的
}
tp.stop(); // 停止线程池的运转
int ans = 0; // 计算出结果
for (int i = 0; i < 10; ++i) {
ans += ret[i];
}
cout << "ans = " << ans << endl;
return 0;
}
任务类的实现
明确一下目的:
实现一个类,第一个参数是一个函数地址,后面为变参列表,该类会将函数与参数打包成一个新的函数,作为任务队列中的一个元素,当空闲线程将其取出之后,可以执行这个新的函数。
这个需要用到模板:
template <typename FUNC_T, typename ...ARGS>
class Task {
Task(FUNC_T func, ARGS... args) {
...
}
private:
...
};
绑定之后我们需要一个变量来存放这个函数,因此需要添加一个函数指针对象 function<void()>
使用 bind
函数进行绑定。
在给bind函数传入参数列表时,需要维持左右值原型,因此需要工具类 std::forward<ARGS>(args)...
来解析参数类型。
template <typename FUNC_T, typename ...ARGS>
class Task {
Task(FUNC_T func, ARGS... args) {
this->_func = bind(functionn, std::forward<ARGS>(args)...);
}
private:
std::function<void()> _func;
};
最后需要一个方法来运行这个函数:
别忘了析构函数。
template <typename FUNC_T, typename ...ARGS>
class Task {
Task(FUNC_T func, ARGS... args) {
this->_func = bind(functionn, std::forward<ARGS>(args)...);
}
~Task() {
delete _func;
}
void run() {
_func();
return ;
}
private:
std::function<void()> _func;
};
线程池类的实现
根据一开始的测试数据,发现线程池的操作对外就支持两个操作:
- 压入任务
- 停止
根据一开始所分析的:
- 存放n个线程对象的空间,可以使用
vector<std::thread>
进行管理。 - 标记每一个线程的工作状态的容器,这里可以使用
unordered_map<std::thread::id, bool>
来进行管理。 - 存放任务的队列
- 等等…(在设计的过程中会发现)
我们可以先将成员和已知的方法写上:
template <typename FUNC_T, typename ...ARGS>
class thread_pool {
public:
thread_pool() {}
template <typename FUNC_T, typename ...ARGS>
void add_task(FUNC_T func, ARGS... args) {...}
void stop() {...}
private:
std::vector<std::thread *> _threads;
unordered_map<std::thread::id, bool> _running;
std::queue<Task *> _tasks;
};
线程池构造函数
先来尝试完善一下构造函数,我们使用参数来控制线程池中的线程个数,默认线程数量我们可以设置成为1
创建出的线程空间放在堆区,因此使用 new
关键字来创建:
thread_pool(int n = 1) {
for(int i = 0; i < n; ++i) {
_threads[i] = new thread(&thread_pool::worker, this); // 别忘了内部方法的第一个隐藏参数this传入进去
}
}
线程池入口函数
这个时候我们发现,由于thread的构造函数需要传入一个需要运行的函数,因此发现又多了一个函数就是工作函数 worker
简单来说,这个函数的功能规定了所有线程的行为——从队列中取出任务并执行。
在工作函数内部,我们需要将该线程 id
记录下来(表示他是否存活),然后不断判断这个线程是否存活,如果存活就继续等待队列中的任务
这个工作函数的作用就是不断检查队列中是否有可以取出的任务,然后执行。
void worker() {
auto id = this_thread::get_id();
_running[id] = true; // 表示这个线程被记录下来,在运行状态
while (_running[id]) {
Task *t = get_task(); // 从队列中取任务,这里又诞生出一个新函数
t->run(); // 运行任务
delete t;
}
}
队列中取任务
可以发现又有了新的函数需求就是从队列中取出一个任务。
这个函数并不对外表现,所以应该设置为私有成员方法。
主要逻辑就是检查队列头部是否有任务对象,如果有的话就返回这个任务的地址。
由于队列是临界资源,所以需要上锁,此时不免又多了两个成员变量
std::mutex m_mutex;
std::condition_variable m_cond;
Task *get_task() {
unique_task<mutex> (m_mutex); // 上锁
while (_tasks.empty()) { // 如果队列为空则释放锁并等待队列被放入任务的条件
m_cond.wait(lock)
}
Task *t = _tasks.front();
_tasks.pop();
return t;
}
这样一来,我们就实现了线程的初始化以及任务的取出。
接下来,还剩下任务的压入,这个操作由 add_task()
实现,因此我们先来实现这个函数
添加任务函数
由于他也是访问临界资源,因此,也需要上锁,同时在添加成功之后释放一个信号。
同样的,这个函数需要在外部调用,因此设置成为共有成员方法:
template <typename FUNC_T, typename ...ARGS>
void add_task(FUNC_T func, ARGS... args) {
_tasks.push(new Task(func, std:forward<ARGS>(args)...));
lock.unlock();
m_cond.notify_one();
return ;
}
线程池终止函数
线程不再运行之后可以选择终止他们来节省计算机资源,因此这个函数是必不可少的,主要的操作方式如下,我们想队列中压入等于同于线程数量的特殊任务这个特殊任务会把线程标记为非活动的,然后后等到他们全部执行完任务后,再依次释放掉他们的资源。
void stop() {
for (int i = 0; i < _thread.size(); ++i) {
this->add_task(&thread_pool::stop_running, this); // 毒药方法
}
for(int i = 0; i < _thread_size(); ++i) {
_threads[i]->join();
}
for(int i = 0; i < _thread.size(); ++i) {
delete _threads[i];
_threads[i] = nullptr;
}
return ;
}
其中涉及到了 stop_running()
这个毒药方法,这个方法只在函数内部调用,因此我们把他设计成为私有成员方法。
这个函数的逻辑就是将当前线程标记为非活动的状态。
void stop_running() {
auto id = this_thread::get_id();
_running[id] = false;
return ;
}
到此为止,线程池的大部分功能就设计的差不多了,之后我又进行了一下细微的调整,相信读者应该自己也能读懂,这里就不过多解释了。
源码
#include <condition_variable>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#include <cstdio>
#include <cstdlib>
#include <ctype.h>
#include <cmath>
#include <string>
#include <sstream>
#include <functional>
#include <thread>
#include <mutex>
#include <condition_variable>
#define TEST_BEGINS(x) namespace x {
#define TEST_ENDS(x) } // end of namespace x
using namespace std;
TEST_BEGINS(thread_pool)
bool is_prime(int n) {
for (int i = 2; i * i <= n; ++i) {
if (n % i == 0) return false;
}
return true;
}
int prime_count(int start, int end) {
int ans = 0;
for (int i = start; i <= end; ++i) {
ans += is_prime(i);
}
return ans;
}
// 多线程入口函数
void worker(int l, int r, int &ret) {
cout << this_thread::get_id() << " : begin" << endl;
ret = prime_count(l, r);
cout << this_thread::get_id() << " : done" << endl;
return ;
}
class Task {
public:
template <typename FUNC_T, typename ...ARGS>
Task(FUNC_T function, ARGS ...args) {
this->func = bind(function, std::forward<ARGS>(args)...);
}
void run() {
func();
return ;
}
private:
function<void()> func;
};
class thread_pool {
public:
thread_pool(int n = 1) : _thread_size(n), _threads(n), starting(false) {
this->start();
}
~thread_pool() {
this->stop();
while (!_tasks.empty()) {
delete _tasks.front();
_tasks.pop();
}
return ;
}
/*
* 入口函数:不断从队列中取任务,然后运行,最后释放资源。
*/
void worker() {
auto id = this_thread::get_id();
_running[id] = true;
while (_running[id]) {
Task *t = get_task();
t->run();
delete t;
}
return ;
}
void start() {
if (starting == true) return ;
for (int i = 0; i < _thread_size; ++i) {
_threads[i] = new thread(&thread_pool::worker, this);
}
starting = true; // 标记线程池运行
return ;
}
void stop() {
if (starting == false) return ;
for (int i = 0; i < _threads.size(); ++i) {
this->add_task(&thread_pool::stop_running, this);
}
for (int i = 0; i < _threads.size(); ++i) {
_threads[i]->join();
}
for (int i = 0; i < _threads.size(); ++i) {
delete _threads[i];
_threads[i] = nullptr;
}
starting = false;
return ;
}
template <typename FUNC_T, typename ...ARGS>
void add_task(FUNC_T func, ARGS... args) {
unique_lock<mutex> lock(m_mutex);
_tasks.push(new Task(func, std::forward<ARGS>(args)...));
lock.unlock();
m_cond.notify_one();
return ;
}
private:
Task *get_task() {
unique_lock<mutex> lock(m_mutex);
while (_tasks.empty()) { // 避免虚假唤醒
m_cond.wait(lock);
}
Task *t = _tasks.front();
_tasks.pop();
return t;
}
void stop_running() {
auto id = this_thread::get_id();
_running[id] = false; // 毒药方法
return ;
}
bool starting;
int _thread_size; // 记录线程个数
std::mutex m_mutex; // 互斥锁
std::condition_variable m_cond; // 条件变量
vector<thread *> _threads; // 线程区
unordered_map<std::thread::id, bool> _running; // 线程活动标记
queue<Task *> _tasks; // 任务队列
};
int main() {
#define MAX_N 5000000
thread_pool tp(10);
int ret[10];
for (int i = 0, j = 1; i < 10; ++i, j += batch) {
tp.add_task(worker, j, j + MAX_N - 1, ref(ret[i]));
}
tp.stop();
int ans = 0;
for (auto x : ret) ans += x;
cout << ans << endl;
return 0;
}
TEST_ENDS(thread_pool)
int main() {
thread_pool::main();
return 0;
}
运行结果: