Thread pools are a basic technique for improving software performance and stability in real-world development. Without multithreading, a program's execution efficiency is low; with too many threads, the operating system is overloaded and both performance and stability suffer. Thread pool technology addresses both concerns: it raises execution efficiency while keeping resource usage within a controllable range. The following is one way to implement a thread pool in C++.
The idea: create a bounded number of threads between a configurable minimum and maximum. By default the minimum is the number of CPU cores and the maximum is twice the minimum. A task queue holds the tasks the threads execute. (Class diagram omitted.)
Contents
1. Priority queue
2. System information
3. Helper class
4. Thread pool class
5. Test code
6. Run results and analysis
1. Priority queue
Because the thread pool executes tasks in priority order, tasks are stored in a std::priority_queue. std::priority_queue is a common STL container adaptor (stack and queue are the others); it arranges the elements of a std::vector or std::deque as a heap.
Header: <queue>
template<
    class T,
    class Container = std::vector<T>,
    class Compare = std::less<typename Container::value_type>
>
class priority_queue;
Note: the third template parameter is the element comparator. If Compare(a, b) returns true, a is ordered before b and ends up farther from the top of the heap; the default std::less therefore yields a max-heap, with the largest element on top.
The priority queue is built on an underlying container that must provide random access via operator[] and iterators, and must support the following operations:
empty()
size()
front()
push_back()
pop_back()
Both deque and vector support these operations, so if no underlying container is specified for priority_queue, vector is used by default.
The following test program demonstrates std::priority_queue:
#include <functional>
#include <iostream>
#include <queue>
#include <vector>

template<typename T>
void print_queue(T& q) {
    while (!q.empty()) {
        std::cout << q.top() << " ";
        q.pop();
    }
    std::cout << '\n';
}

// max-heap comparator
struct MyCmp {
    bool operator()(const int& a, const int& b) const
    {
        return a < b;
    }
};

// min-heap comparator
struct MyCmp1 {
    bool operator()(const int& a, const int& b) const
    {
        return a > b;
    }
};

int main() {
    // default: max-heap, pops in descending order
    std::priority_queue<int> queue1;
    for (int n : {11,88,55,66,33,44,0,99,77,22}) {
        queue1.push(n);
    }
    print_queue(queue1);

    // std::greater: min-heap, pops in ascending order
    std::priority_queue<int, std::vector<int>, std::greater<int> > queue2;
    for (int n : {11,88,55,66,33,44,0,99,77,22}) {
        queue2.push(n);
    }
    print_queue(queue2);

    // lambda comparator; decltype (C++11) deduces its type at compile time,
    // and the lambda object must be passed to the constructor
    auto cmp = [](const int& left, const int& right) { return left < right; };
    std::priority_queue<int, std::vector<int>, decltype(cmp)> queue3(cmp);
    for (int n : {11,88,55,66,33,44,0,99,77,22}) {
        queue3.push(n);
    }
    print_queue(queue3);

    std::priority_queue<int, std::vector<int>, MyCmp> queue4;
    for (int n : {11,88,55,66,33,44,0,99,77,22}) {
        queue4.push(n);
    }
    print_queue(queue4);

    std::priority_queue<int, std::vector<int>, MyCmp1> queue5;
    for (int n : {11,88,55,66,33,44,0,99,77,22}) {
        queue5.push(n);
    }
    print_queue(queue5);
    return 0;
}
Running the program produces the following output (the input sequence is fixed, so the result is deterministic):
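99 88 77 66 55 44 33 22 11 0
0 11 22 33 44 55 66 77 88 99
99 88 77 66 55 44 33 22 11 0
99 88 77 66 55 44 33 22 11 0
0 11 22 33 44 55 66 77 88 99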
2. System information
The following class retrieves the number of logical CPU cores, the physical memory size, and kernel/CPU architecture information.
sysinfo.h
#pragma once
#include <cstdint>  // uint32_t, uint64_t
#include <string>
#include "uncopyable.h"

class SysInfo {
    DECLARE_STATIC(SysInfo);
public:
    /// get logical cpu number
    static uint32_t get_logical_cpu_number();
    /// get physical memory size in bytes
    static uint64_t get_physical_memory_size();
    /// get kernel release string
    static bool get_kernel_info(std::string& kernel_info);
    /// get machine architecture string
    static bool get_machine_architecture(std::string& arch_info);
};
sysinfo.cpp
#include <unistd.h>
#include <sys/utsname.h>
#include "sysinfo.h"

/// internal helper: fill a function-local static utsname via uname()
static const struct utsname* get_uname_info()
{
    static struct utsname uts;
    if (uname(&uts) == -1)
        return NULL;
    return &uts;
}

/// SysInfo interfaces
uint32_t SysInfo::get_logical_cpu_number()
{
    return sysconf(_SC_NPROCESSORS_ONLN);
}

uint64_t SysInfo::get_physical_memory_size()
{
    return sysconf(_SC_PHYS_PAGES) * (uint64_t)sysconf(_SC_PAGESIZE);
}

bool SysInfo::get_kernel_info(std::string& kernel_info)
{
    const struct utsname* uts = get_uname_info();
    if (uts != NULL) {
        kernel_info.assign(uts->release);
        return true;
    }
    return false;
}

bool SysInfo::get_machine_architecture(std::string& arch_info)
{
    const struct utsname* uts = get_uname_info();
    if (uts != NULL) {
        arch_info.assign(uts->machine);
        return true;
    }
    return false;
}
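A minimal usage sketch for SysInfo (the main function below is illustrative, not part of the library):
// sysinfo_demo.cpp -- illustrative only
#include <iostream>
#include <string>
#include "sysinfo.h"

int main() {
    std::string kernel, arch;
    std::cout << "logical CPUs: " << SysInfo::get_logical_cpu_number() << std::endl;
    std::cout << "physical memory (bytes): " << SysInfo::get_physical_memory_size() << std::endl;
    if (SysInfo::get_kernel_info(kernel))
        std::cout << "kernel: " << kernel << std::endl;
    if (SysInfo::get_machine_architecture(arch))
        std::cout << "arch: " << arch << std::endl;
    return 0;
}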
3. Helper class
This helper provides utilities that prevent classes from being copied.
uncopyable.h
#pragma once

namespace uncopyable
{
class Uncopyable
{
protected:
    Uncopyable() {}
    ~Uncopyable() {}
private:
    Uncopyable(const Uncopyable&);
    const Uncopyable& operator=(const Uncopyable&);
};
}  // namespace uncopyable

typedef uncopyable::Uncopyable Uncopyable;

#define DECLARE_UNCOPYABLE(Class) \
private: \
    Class(const Class&); \
    Class& operator=(const Class&)

#define DECLARE_STATIC(Class) \
private: \
    Class() = delete; \
    ~Class() = delete
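A quick sketch of how these helpers are used (Logger and Config are made-up example classes):
// illustrative only: Logger and Config are hypothetical
#include "uncopyable.h"

class Logger : private Uncopyable { // inheriting disables copying
public:
    Logger() {}
};

class Config {
    DECLARE_UNCOPYABLE(Config); // macro form, as used by ThreadPool below
public:
    Config() {}
};

int main() {
    Logger a;
    // Logger b = a;   // error: copy constructor is private
    Config c;
    // Config d = c;   // error: copy constructor is private
    return 0;
}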
4. Thread pool class
The thread pool implementation:
threadpool.h
#pragma once
#include <cstdint>
#include <map>
#include <memory>
#include <queue>
#include <vector>
#include <list>
#include <string>
#include <functional>
#include <mutex>
#include <thread>
#include <condition_variable>
#include <atomic>
#include "uncopyable.h"

class ThreadPool
{
    DECLARE_UNCOPYABLE(ThreadPool);
public:
    typedef std::function<void ()> TaskHandler;
public:
    explicit ThreadPool(
        int min_num_threads = -1,
        int max_num_threads = -1,
        const std::string& name = "");
    ~ThreadPool();
public:
    // set the minimum number of threads
    void set_min_thread_number(int num_threads);
    // set the maximum number of threads
    void set_max_thread_number(int num_threads);
    // add a task (default priority)
    uint64_t add_task(const TaskHandler& callback);
    // add a task with an explicit priority (lower value runs first)
    uint64_t add_task(const TaskHandler& callback, int priority);
    // cancel a task that has not yet been dequeued
    bool cancel_task(uint64_t task_id);
    // stop all threads and clear pending tasks
    void terminate();
    void clear_tasks();
    void wait_for_idle();
    void get_stats() const;
private:
    struct Task
    {
        Task(const TaskHandler& entry,
             uint64_t id,
             int priority) :
            on_schedule(entry),
            id(id),
            priority(priority),
            is_canceled(false) {}
        void set_schedule_timeout_flag(bool flag)
        {
            is_canceled = flag;
        }
        void set_cancel_flag(bool flag)
        {
            is_canceled = flag;
        }
        bool check_cancel_flag() const
        {
            return is_canceled;
        }
        TaskHandler on_schedule;
        uint64_t id;          // task id
        int priority;         // task priority, lower is better
        bool is_canceled;     // task flag, schedule timeout or canceled
        std::mutex task_lock; // task inner lock
    };
    struct ThreadContext
    {
        ThreadContext() :
            waiting_timer_id(0),
            is_waiting_timeout(false) {}
        void set_waiting_timeout_flag(bool flag)
        {
            is_waiting_timeout = flag;
        }
        bool check_waiting_timeout_flag() const
        {
            return is_waiting_timeout;
        }
        std::shared_ptr<std::thread> thread;
        uint64_t waiting_timer_id;
        bool is_waiting_timeout;
    };
    typedef std::map<uint64_t, std::shared_ptr<Task> > TaskMap;
    typedef std::list<ThreadContext*> ThreadList;
private:
    // worker thread routine
    void work(ThreadContext* thread);
private:
    /// @brief auto generate a new task id
    uint64_t new_task_id();
    uint64_t add_task_internal(const TaskHandler& callback,
                               int priority);
    bool dequeue_task_in_lock(std::shared_ptr<Task>& task);
    bool need_new_thread() const;
    bool need_shrink_thread() const;
    // add one thread to the pool
    void expand_thread();
private:
    struct TaskCompare
    {
        bool operator()(const std::shared_ptr<Task>& a,
                        const std::shared_ptr<Task>& b) const
        {
            return a->priority > b->priority;
        }
    };
private:
    int m_min_num_threads;
    int m_max_num_threads;
    // current threads number
    std::atomic<int> m_num_threads;
    // current on-busy threads number
    std::atomic<int> m_num_busy_threads;
    // tasks container
    TaskMap m_tasks;
    // exit flag; atomic, since it is also read outside m_lock
    std::atomic<bool> m_exit;
    mutable std::mutex m_lock;
    // signaled when all threads become free
    std::condition_variable m_exit_cond;
    // signaled when tasks are enqueued
    std::condition_variable m_task_cond;
    // not suitable for std::vector since uncopyable structures
    // list with m_num_threads elements running with each work thread routine
    ThreadList m_threads;
    // task queue: a min-heap on priority, so a lower priority value runs first
    std::priority_queue<
        std::shared_ptr<Task>,
        std::vector<std::shared_ptr<Task> >,
        TaskCompare> m_task_queue;
};
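As a standalone sanity check of the min-heap claim above, the following sketch reproduces the TaskCompare ordering with a simplified stand-in struct (Task itself is private to ThreadPool, so DemoTask here is hypothetical):
// standalone sketch: DemoTask stands in for ThreadPool::Task
#include <iostream>
#include <memory>
#include <queue>
#include <vector>

struct DemoTask {
    int priority;
};

struct DemoCompare {
    bool operator()(const std::shared_ptr<DemoTask>& a,
                    const std::shared_ptr<DemoTask>& b) const {
        return a->priority > b->priority; // same ordering as TaskCompare
    }
};

int main() {
    std::priority_queue<std::shared_ptr<DemoTask>,
                        std::vector<std::shared_ptr<DemoTask> >,
                        DemoCompare> q;
    for (int p : {10, 0, 5}) {
        std::shared_ptr<DemoTask> t(new DemoTask());
        t->priority = p;
        q.push(t);
    }
    while (!q.empty()) {
        std::cout << q.top()->priority << " "; // prints: 0 5 10
        q.pop();
    }
    std::cout << std::endl;
    return 0;
}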
threadpool.cpp
#include <cstddef>
#include <iostream>
#include <string>
#include <atomic>
#include <chrono>
#include "threadpool.h"
#include "sysinfo.h"

using namespace std;

static uint64_t s_thread_name_index = 0;

ThreadPool::ThreadPool(
    int min_num_threads, int max_num_threads, const std::string& name) :
    m_min_num_threads(0),
    m_max_num_threads(0),
    m_num_threads(0),
    m_num_busy_threads(0),
    m_exit(false)
{
    // default minimum: number of logical CPU cores
    if (min_num_threads <= 0) {
        m_min_num_threads = SysInfo::get_logical_cpu_number();
    } else {
        m_min_num_threads = min_num_threads;
    }
    // default maximum: twice the minimum
    if (max_num_threads < m_min_num_threads) {
        m_max_num_threads = 2 * m_min_num_threads;
    } else {
        m_max_num_threads = max_num_threads;
    }
    // start the minimum number of worker threads
    std::unique_lock<std::mutex> lock(m_lock);
    for (int i = 0; i < m_min_num_threads; i++) {
        ThreadContext* thread = new ThreadContext();
        thread->thread.reset(
            new std::thread(std::bind(
                &ThreadPool::work, this, thread)));
        m_threads.push_back(thread);
        m_num_threads++;
    }
    s_thread_name_index += m_min_num_threads;
}
ThreadPool::~ThreadPool()
{
    terminate();
}

void ThreadPool::terminate()
{
    std::cout << __LINE__ << " " << __FUNCTION__ << std::endl;
    if (m_exit) {
        return;
    }
    {
        std::unique_lock<std::mutex> lock(m_lock);
        m_exit = true;
        // wake every thread blocked on the task condition so it can exit
        m_task_cond.notify_all();
        // wait until all busy threads have exited
        while (m_num_busy_threads > 0) {
            m_exit_cond.wait(lock);
        }
    }
    for (auto it = m_threads.begin(); it != m_threads.end(); it++) {
        if ((*it)->thread->joinable()) {
            (*it)->thread->join();
        }
    }
    // threads clear
    while (!m_threads.empty()) {
        ThreadContext* thread = m_threads.front();
        m_threads.pop_front();
        thread->thread.reset();
        delete thread;
    }
    m_num_threads = 0;
    m_num_busy_threads = 0;
    // tasks clear
    clear_tasks();
}
void ThreadPool::clear_tasks()
{
    std::unique_lock<std::mutex> lock(m_lock);
    while (!m_task_queue.empty()) {
        m_task_queue.pop();
    }
    // also drop the id -> task map so stale entries cannot be "canceled"
    m_tasks.clear();
}
void ThreadPool::wait_for_idle()
{
    if (m_exit) {
        return;
    }
    // poll until the queue is empty and no thread is busy
    for (;;) {
        {
            std::unique_lock<std::mutex> lock(m_lock);
            if (m_task_queue.empty() && m_num_busy_threads == 0) {
                return;
            }
        }
        // sleep for one second between polls
        this_thread::sleep_for(chrono::seconds(1));
    }
}
void ThreadPool::set_min_thread_number(int num_threads)
{
    if (m_exit) {
        return;
    }
    if (num_threads <= 0) {
        m_min_num_threads = SysInfo::get_logical_cpu_number();
    } else {
        m_min_num_threads = num_threads;
    }
}

void ThreadPool::set_max_thread_number(int num_threads)
{
    if (m_exit) {
        return;
    }
    if (num_threads < m_min_num_threads) {
        m_max_num_threads = 2 * m_min_num_threads;
    } else {
        m_max_num_threads = num_threads;
    }
}
uint64_t ThreadPool::add_task(const TaskHandler& callback)
{
    // default priority is 10
    return add_task_internal(callback, 10);
}

uint64_t ThreadPool::add_task(const TaskHandler& callback, int priority)
{
    return add_task_internal(callback, priority);
}

bool ThreadPool::cancel_task(uint64_t task_id)
{
    std::unique_lock<std::mutex> lock(m_lock);
    TaskMap::iterator it = m_tasks.find(task_id);
    if (it != m_tasks.end()) {
        it->second->set_cancel_flag(true);
        return true;
    } else {
        // already dequeued (possibly running) or never existed
        return false;
    }
}

void ThreadPool::get_stats() const
{
    std::unique_lock<std::mutex> lock(m_lock);
    std::cout << "######## ThreadPool Stats ################" << std::endl;
    std::cout << "m_min_num_threads:" << m_min_num_threads << std::endl;
    std::cout << "m_max_num_threads:" << m_max_num_threads << std::endl;
    std::cout << "m_num_threads:" << m_num_threads << std::endl;
    std::cout << "m_num_busy_threads:" << m_num_busy_threads << std::endl;
    std::cout << "m_threads size:" << m_threads.size() << std::endl;
    std::cout << "##########################################" << std::endl;
}
// worker thread routine
void ThreadPool::work(ThreadContext* thread)
{
    m_num_busy_threads++;
    for (;;) {
        std::shared_ptr<Task> task;
        {
            std::unique_lock<std::mutex> lock(m_lock);
            if (m_exit || thread->check_waiting_timeout_flag()) {
                break;
            }
            if (!dequeue_task_in_lock(task)) {
                // nothing to do: mark idle and sleep until a task arrives
                m_num_busy_threads--;
                m_task_cond.wait(lock);
                m_num_busy_threads++;
                continue;
            }
            // skip tasks that were canceled while still queued
            // (the flag is only written under m_lock, so check it here)
            if (task->check_cancel_flag()) {
                continue;
            }
        }
        if (!task) {
            continue;
        }
        // execute task
        if (task->on_schedule) {
            task->on_schedule();
        }
    }
    // thread is exiting: update counters and wake terminate() if it waits
    {
        std::unique_lock<std::mutex> lock(m_lock);
        m_num_threads--;
        m_num_busy_threads--;
        if (m_num_busy_threads == 0) {
            m_exit_cond.notify_all();
        }
    }
}
/// private methods
static std::atomic<size_t> s_task_id(0);

uint64_t ThreadPool::new_task_id()
{
    return static_cast<uint64_t>(++s_task_id);
}

uint64_t ThreadPool::add_task_internal(const TaskHandler& callback,
                                       int priority)
{
    if (m_exit) {
        return 0;
    }
    uint64_t id = new_task_id();
    std::shared_ptr<Task> task(new Task(callback, id, priority));
    {
        std::unique_lock<std::mutex> lock(m_lock);
        // check whether we need to expand the pool
        if (need_new_thread()) {
            expand_thread();
        }
        // add to task map
        m_tasks[id] = task;
        // push into priority task queue
        m_task_queue.push(task);
        m_task_cond.notify_all();
    }
    return id;
}

bool ThreadPool::need_new_thread() const
{
    if (m_num_threads >= m_max_num_threads) {
        return false;
    }
    // grow while below the minimum, or when every thread is busy
    if (m_num_threads < m_min_num_threads ||
        m_num_threads == m_num_busy_threads) {
        return true;
    }
    return false;
}

bool ThreadPool::need_shrink_thread() const
{
    // note: currently unused; reserved for idle-timeout shrinking
    if (m_num_threads > m_min_num_threads) {
        return true;
    }
    return false;
}

void ThreadPool::expand_thread()
{
    // caller must hold m_lock
    ThreadContext* thread = new ThreadContext();
    thread->thread.reset(
        new std::thread(std::bind(
            &ThreadPool::work, this, thread)));
    // add the new thread to the thread list
    m_threads.push_back(thread);
    m_num_threads++;
}

bool ThreadPool::dequeue_task_in_lock(std::shared_ptr<Task>& task)
{
    // caller must hold m_lock
    if (m_task_queue.empty()) {
        return false;
    }
    task = m_task_queue.top();
    // remove from task queue
    m_task_queue.pop();
    // remove from task map
    m_tasks.erase(task->id);
    return true;
}
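Before the full test program, here is a minimal usage sketch of the pool (illustrative only; the lambdas and priority values are placeholders):
// minimal_usage.cpp -- illustrative sketch
#include <iostream>
#include "threadpool.h"

int main() {
    ThreadPool pool; // min = CPU cores, max = 2 * min by default
    // a lower priority value runs first
    pool.add_task([]() { std::cout << "task A" << std::endl; }, 1);
    uint64_t id = pool.add_task([]() { std::cout << "task B" << std::endl; }, 0);
    pool.cancel_task(id); // best-effort: only works if B has not been dequeued yet
    pool.wait_for_idle(); // poll until the queue is empty and no thread is busy
    pool.terminate();
    return 0;
}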
5. Test code
#include <cstdio>   // getchar
#include <cstdlib>
#include <ctime>
#include <chrono>
#include <iostream>
#include "threadpool.h"

using namespace std;

class HttpClient {
public:
    static HttpClient* getInstance() {
        static HttpClient s_instance;
        return &s_instance;
    }
    ~HttpClient() {
    }
    bool ConnServer() {
        int iReqId = 1;
        int timeout_in_ms = 100;
        auto execute_fun = [this, iReqId, timeout_in_ms]() {
            std::cout << "HttpClient:: begin " << iReqId << " " << timeout_in_ms << std::endl;
            this->UseTimeFun();
            std::cout << "HttpClient:: end" << std::endl;
            m_thread_pool.get_stats();
        };
        iReqId++; // captured by value, so execute_fun1 sees 2
        auto execute_fun1 = [this, iReqId, timeout_in_ms]() {
            std::cout << "HttpClient:: begin " << iReqId << " " << timeout_in_ms << std::endl;
            this->UseTimeFun();
            std::cout << "HttpClient:: end" << std::endl;
            m_thread_pool.get_stats();
        };
        m_thread_pool.add_task(execute_fun, 1);
        m_thread_pool.add_task(execute_fun1, 0); // lower value: runs first
        return true;
    }
private:
    // simulate a slow HTTP request with a random sleep
    void UseTimeFun() {
        srand((unsigned)time(NULL));
        this_thread::sleep_for(chrono::seconds(rand() % 10));
    }
private:
    ThreadPool m_thread_pool;
};

class DBClient {
public:
    static DBClient* getInstance() {
        static DBClient s_instance;
        return &s_instance;
    }
    ~DBClient() {
    }
    bool ConnServer() {
        std::string uname = "mysql";
        std::string passwd = "mysql";
        auto execute_fun = [this, uname, passwd]() {
            std::cout << "DBClient:: begin " << uname << " " << passwd << std::endl;
            this->OperFun();
            std::cout << "DBClient:: end " << std::endl;
            m_thread_pool.get_stats();
        };
        m_thread_pool.add_task(execute_fun, 1);
        return true;
    }
private:
    // simulate a slow database operation with a random sleep
    void OperFun() {
        srand((unsigned)time(NULL));
        this_thread::sleep_for(chrono::seconds(rand() % 10));
    }
private:
    ThreadPool m_thread_pool;
};

int main(int argc, char* argv[]) {
    HttpClient::getInstance()->ConnServer();
    DBClient::getInstance()->ConnServer();
    getchar(); // keep the process alive while the tasks run
    return 0;
}
Makefile
app: useThreadPool

# note: $^ expands to all prerequisites of the rule
useThreadPool: main.cpp threadpool.cpp sysinfo.cpp
	g++ -g -std=c++11 $^ -o useThreadPool -lpthread

clean:
	-rm -f useThreadPool
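Build with make and run ./useThreadPool; the program blocks on getchar(), so press Enter to exit after the tasks finish.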
6. Run results and analysis
The functions that need their own threads are all time-consuming; the test program uses sleep to simulate an HTTP request and a database operation, and submits both to thread pools for execution.
From the run output: my machine has 12 cores, so the default minimum and maximum thread counts are 12 and 24. Each pool starts its minimum number of threads, with the busy-thread count initially 0. After HttpClient enqueues the time-consuming tasks execute_fun and execute_fun1, its pool's busy-thread count becomes 2; after DBClient enqueues its time-consuming task, its own pool's busy-thread count becomes 1.