【项目实战】基于高并发服务器的搜索引擎
目录
- 【项目实战】基于高并发服务器的搜索引擎
- 搜索引擎部分代码
- index.html
- index.hpp
- log.hpp
- parser.cc(用于对网页的html文件切分且存储索引关系)
- searcher.hpp
- util.hpp
- http_server.cc(用于启动服务器和搜索引擎)
- 高并发服务器部分
- http.hpp
- server.hpp
- httplib.h(第三方库)
- cppjieba(第三方库)
- 目录结构
作者:爱写代码的刚子
时间:2024.4.24
前言:基于高并发服务器的搜索引擎,引用了第三方库cpp-httplib,cppjieba,项目的要点在代码注释中了
搜索引擎部分代码
index.html
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<meta http-equiv="X-UA-Compatible" content="IE=edge">
<script src="https://cdn.jsdelivr.net/npm/jquery@3.5.1/dist/jquery.min.js"></script>
<title>本地 boost 搜索引擎</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
html, body {
height: 100%;
font-family: Arial, sans-serif;
}
.container {
width: 100%;
display: flex;
flex-direction: column;
align-items: center;
}
.title {
width: 100%;
background-color: #4e6ef2;
color: #fff;
text-align: center;
padding: 10px 0;
font-size: 24px;
font-weight: bold;
}
.search-container {
width: 100%;
background-color: #f2f2f2;
display: flex;
justify-content: center;
align-items: center;
padding: 20px 0;
position: relative;
}
.search-input {
width: calc(100% - 130px); /* 调整搜索框宽度 */
max-width: 300px; /* 设置最大宽度 */
height: 40px;
padding: 10px;
border: 1px solid #ccc;
border-radius: 20px;
font-size: 16px;
outline: none;
}
.search-btn {
width: 100px; /* 调整按钮宽度 */
height: 40px;
background-color: #4e6ef2;
color: #fff;
border: none;
border-radius: 20px;
font-size: 16px;
cursor: pointer;
position: absolute;
right: 10px;
}
.result-container {
width: 100%;
padding: 20px 0;
display: flex;
flex-direction: column;
align-items: center;
}
.result-item {
width: 90%; /* 修改为百分比宽度,更适应移动设备 */
max-width: 800px; /* 设置最大宽度 */
border: 1px solid #ccc;
border-radius: 5px;
padding: 10px;
margin-top: 10px;
}
.result-title {
font-size: 18px;
color: #4e6ef2;
text-decoration: none;
}
.result-desc {
font-size: 14px;
color: #333;
margin-top: 5px;
}
.result-url {
font-size: 12px;
color: #666;
margin-top: 5px;
}
</style>
</head>
<body>
<div class="container">
<div class="title">boost 搜索引擎</div>
<div class="search-container">
<input type="text" class="search-input" value="输入搜索关键字..." onfocus="if(this.value=='输入搜索关键字...') this.value='';" onblur="if(this.value=='') this.value='输入搜索关键字...';">
<button class="search-btn" onclick="Search()">搜索一下</button>
</div>
<div class="result-container">
<!-- 搜索结果动态生成 -->
</div>
</div>
<script>
function Search() {
let query = $(".search-input").val().trim();
if (query == '') {
return;
}
$.ajax({
type: "GET",
url: "/s?word=" + query,
dataType: "json",
success: function (data) {
BuildHtml(data);
$(".search-input").css("margin-top", "20px");
}
});
}
function BuildHtml(data) {
let result_container = $(".result-container");
result_container.empty();
if (!data || data.length === 0) {
result_container.append("<div>未找到相关结果</div>");
return;
}
for (let elem of data) {
let item = $("<div>", {class: "result-item"});
let title = $("<a>", {class: "result-title", href: elem.url, text: elem.title, target: "_blank"});
let desc = $("<div>", {class: "result-desc", text: elem.desc});
let url = $("<div>", {class: "result-url", text: elem.url});
title.appendTo(item);
desc.appendTo(item);
url.appendTo(item);
item.appendTo(result_container);
}
}
</script>
</body>
</html>
index.hpp
#pragma once
#include <iostream>
#include <vector>
#include <string>
#include <fstream>
#include <unordered_map>
#include <mutex>
#include "util.hpp"
#include "log.hpp"
namespace ns_index{
struct DocInfo{
std::string title;//文档标题
std::string content;//文档对应的去标签之后的内容
std::string url;//官网文档url
uint64_t doc_id; //文档的ID
};
struct InvertedElem{//倒排的元素
uint64_t doc_id;
std::string word;
int weight;
};
//倒排拉链
typedef std::vector<InvertedElem> InvertedList;
class Index{
private:
//正排索引的数据结构用数组,数组的下标天然是文档的ID
std::vector<DocInfo> forward_index;//正排索引
//倒排索引一定是一个关键字和一组(个)InvertedElem对应(关键字和倒排拉链的对应关系)
std::unordered_map<std::string , InvertedList>inverted_index;
private:
Index(){}//单例,但是不能delete
Index(const Index&) = delete;
Index& operator = (const Index&) = delete;
static Index *instance;
static std::mutex mtx;
public:
~Index(){}
public:
static Index* GetInstance()//多线程环境会存在线程安全
{
if(nullptr==instance)
{
mtx.lock();
if(nullptr==instance)
{
instance = new Index();
}
mtx.unlock();
}
return instance;
}
//根据doc_id找到文档内容
DocInfo* GetForwardIndex(uint64_t doc_id)
{
if(doc_id >= forward_index.size())
{
//std::cerr<<"doc_id out range,error!"<<std::endl;
LOG2(DEBUG,"doc_id out range,error!");
return nullptr;
}
return &forward_index[doc_id];
}
//根据关键字string,获得倒排拉链
InvertedList *GetInvertedList(const std::string &word)
{
auto iter = inverted_index.find(word);
if(iter==inverted_index.end())
{
//std::cerr<<word<<"have no InvertedList"<<std::endl;
LOG2(WARNING,"用户没搜到");
return nullptr;
}
return &(iter->second);
}
//根据去标签,格式化之后的文档,构建正排和倒排索引
//data/raw_html/raw.txt
bool BuildIndex(const std::string &input)//parse处理完毕的数据交给我(文件的路径)
{
std::ifstream in(input,std::ios::in | std::ios::binary);
if(!in.is_open()){
//std::cerr<<"sorry,"<<input<<"open error"<<std::endl;
LOG2(FATAL,"open error");
return false;
}
//读取文件
std::string line;//每一行是一个文件
int count = 0;
while(std::getline(in,line))
{
//建立正排索引
DocInfo* doc=BuildForwardIndex(line);
if(doc==nullptr)
{
//std::cerr<<"build"<<line<<"error"<<std::endl;//for debug
LOG2(DEBUG,"建立正排索引错误");
continue;
}
BuildInvertedIndex(*doc);
count++;
if(count % 50==0)
{
//std::cout<< "当前已经建立的索引文档:"<<count <<std::endl;
LOG2(NORMAL,"当前已经建立的索引文档: " + std::to_string(count));
}
}
return true;
}
private:
DocInfo *BuildForwardIndex(const std::string &line)
{
//1. 解析line,字符串切分 line -> 3个string,(title、content、url)
std::vector<std::string> results;
const std::string sep ="\3";//行内分隔符
ns_util::StringUtil::Split(line,&results,sep);
if(results.size()!=3){
return nullptr;
}
//2. 字符串进行填充到DoInfo
DocInfo doc;
doc.title = results[0];
doc.content = results[1];
doc.url = results[2];
doc.doc_id = forward_index.size();//先进行保存,再插入,对应的id就是当前doc在vector下的下标
//3. 插入到正排索引的vector
forward_index.push_back(std::move(doc));//doc.html文件内容会比较大,避免拷贝应使用move
return &forward_index.back();
}
bool BuildInvertedIndex(const DocInfo &doc)
{
//DocInfo(title,content,url,doc_id)
//world -> 倒排拉链
struct word_cnt{
int title_cnt;
int content_cnt;
word_cnt():title_cnt(0),content_cnt(0){}
};
std::unordered_map<std::string,word_cnt> word_map;//用来暂存词频的映射表
//对标题进行分词
std::vector<std::string> title_words;
ns_util::JiebaUtil::CutString2(doc.title,&title_words);//调用了CutString2
//对标题进行词频统计
for(auto &s : title_words){
boost::to_lower(s);
word_map[s].title_cnt++;
}
//对文档内容进行分词
std::vector<std::string> content_words;
ns_util::JiebaUtil::CutString2(doc.content,&content_words);
//对内容进行词频统计
for(auto &s : content_words){
boost::to_lower(s);
word_map[s].content_cnt++;
}
#define X 10
#define Y 1
//Hello.HELLO.hello(倒排索引的大小写要忽略)
//根据文档内容,形成一个或者多个InvertedElem(倒排拉链)
//因为当前我们是一个一个文档进行处理的,一个文档会包含多个“词”,都应当对应到当前的doc_id
for(auto &word_pair : word_map){
InvertedElem item;
item.doc_id = doc.doc_id;
item.word = word_pair.first;
item.weight = X*word_pair.second.title_cnt + Y*word_pair.second.content_cnt;//相关性
InvertedList &inverted_list = inverted_index[word_pair.first];
inverted_list.push_back(std::move(item));
}
//1.需要对title && content都要先分词
//title: 吃/葡萄
//content:吃/葡萄/不吐/葡萄皮
//词和文档的相关性(非常复杂,我们采用词频:在标题中出现的词,可以认为相关性更高一些,在内容中出现相关性低一些)
//2.词频统计
//知道了在文档中,标题和内容每个词出现的次数
//3. 自定义相关性
//jieba的使用————cppjieba
return true;
}
};
Index* Index::instance = nullptr;
std::mutex Index::mtx;
}
log.hpp
#pragma once
#include <iostream>
#include <string>
#include <ctime>
#define NORMAL 1
#define WARNING 2
#define DEBUG 3
#define FATAL 4
#define LOG2(LEVEL,MESSAGE) log(#LEVEL,MESSAGE,__FILE__,__LINE__)
//@brief:时间戳转日期时间
static inline std::string getDateTimeFromTS(time_t ts) {
if(ts<0) {
return "";
}
struct tm tm = *localtime(&ts);
static char time_str[32]{0};
snprintf(time_str,sizeof(time_str),"%04d-%02d-%02d %02d:%02d:%02d",tm.tm_year+1900,tm.tm_mon+1,tm.tm_mday,tm.tm_hour,tm.tm_min,tm.tm_sec);
return std::string(time_str);
}
void log(std::string level,std::string message,std::string file,int line)
{
std::cout<<"["<<level<<"]"<<"["<<getDateTimeFromTS(time(nullptr))<<"]"<<"["<<message<<"]"<<"["<<file<<":"<<line<<"]"<<std::endl;
}
parser.cc(用于对网页的html文件切分且存储索引关系)
#include <iostream>
#include <string>
#include <vector>
#include <boost/filesystem.hpp>
#include "util.hpp"
#include "log.hpp"
const std::string src_path = "data/input";
const std::string output = "data/raw_html/raw.txt";//结尾没有'/'
typedef struct DocInfo{
std::string title;//文档的标题
std::string content;//文档内容
std::string url;//该文档在官网中的url
}DocInfo_t;
//const & 输入
//* 输出
//& 输入输出
bool EnumFile(const std::string &src_path,std::vector<std::string> *file_list);
bool ParseHtml(const std::vector<std::string> &files_list,std::vector<DocInfo_t> *results);
bool SaveHtml(const std::vector<DocInfo_t> &results,const std::string &output);
int main()
{
std::vector<std::string> files_list;
//第一步,递归式的把每个html文件名带路径,保存到files_list中,方便后期进行一个一个的文件进行读取
if(!EnumFile(src_path, &files_list))
{
//std::cerr<<"enum file error!" <<std::endl;
LOG2(FATAL,"enum file error!");
return 1;
}
//第二步,按照files_list读取每个文件的内容,并进行解析
std::vector<DocInfo_t> results;
if(!ParseHtml(files_list,&results))
{
//std::cerr <<"parse html error"<<std::endl;
LOG2(FATAL,"parse html error");
return 2;
}
//第三步,把解析完毕的各个文件的内容,写入到output中,按照\3作为每个文档的分割符
if(!SaveHtml(results,output))
{
//std::cerr<<"save html error"<<std::endl;
LOG2(FATAL,"save html error");
return 3;
}
return 0;
}
bool EnumFile(const std::string &src_path,std::vector<std::string> *files_list)
{
namespace fs = boost::filesystem;
fs::path root_path(src_path);
//判断路径是否存在,不存在就没必要往后走了
if(!fs::exists(root_path))
{
//std::cerr<< src_path<<"not exists"<<std::endl;
LOG2(FATAL,"src_path not exists");
return false;
}
//定义一个空的迭代器,用来进行判断递归结束
fs::recursive_directory_iterator end;
for(fs::recursive_directory_iterator iter(root_path);iter != end;iter++){
//判断文件是否是普通文件(html是普通文件)
if(!fs::is_regular_file(*iter))
{
continue;
}
if(iter->path().extension()!= ".html"){//判断文件路径名的后缀是否符合要求 path()提取路径字符串,是一个路径对象 ,extension()提取后缀(.以及之后的部分)
continue;
}
//std::cout<<"debug: " <<iter->path().string()<<std::endl;
//当前的路径一定是一个合法的,以.html结束的普通网页文件、
files_list->push_back(iter->path().string());//将所有带路径的html保存到files_list,方便后续进行文本分析
}
return true;
}
static bool ParseTitle(const std::string &file,std::string *title){
std::size_t begin = file.find("<title>");
if(begin == std::string::npos){
return false;
}
std::size_t end = file.find("</title>");
if(end==std::string::npos)
{
return false;
}
begin+=std::string("<title>").size();
if(begin>end){
return false;
}
*title = file.substr(begin,end-begin);
return true;
}
static bool ParseContent(const std::string &file,std::string *content){
//去标签,基于一个简易的状态机编写
enum status{
LABLE,
CONTENT
};
enum status s=LABLE;
for(char c :file){
switch(s)
{
case LABLE:
if(c=='>') s= CONTENT;
break;
case CONTENT:
if(c=='<') s= LABLE;
else
{
//我们不想要保留原始文件中的‘\n’,因为我们想用\n作为html解析之后文本的分隔符
if(c=='\n')c=' ';
content->push_back(c);
}
break;
default:
break;
}
}
return true;
}
static bool ParseUrl(const std::string &file_path,std::string *url)
{
std::string url_head = "https://www.boost.org/doc/libs/1_78_0/doc/html";
std::string url_tail = file_path.substr(src_path.size());//越过长度截取
*url = url_head + url_tail;
return true;
}
//for debug
static void ShowDoc(const DocInfo_t &doc)
{
std::cout<<"title:"<<doc.title << std::endl;
std::cout<<"content:"<<doc.content << std::endl;
std::cout<<"url:"<<doc.url << std::endl;
}
bool ParseHtml(const std::vector<std::string> &files_list,std::vector<DocInfo_t> *results)
{
for(const std::string &file : files_list)
{
//1.读取文件,Read()
std::string result;
if(!ns_util::FileUtil::ReadFile(file,&result)){
continue;
}
//2.解析指定的文件,提取title
DocInfo_t doc;
if(!ParseTitle(result,&doc.title)){
continue;
}
//3.解析指定的文件,提取content
if(!ParseContent(result,&doc.content)){
continue;
}
//4.解析指定的文件路径,构建url
if(!ParseUrl(file,&doc.url)){
continue;
}
//done,一定是完成了解析任务,当前文档的相关结果都保存在doc中
results->push_back(std::move(doc)); //bug to do细节,本质会发生拷贝,效率可能会比较低 (move是细节)
//std::cout<<1<<std::endl;
//for debug
//ShowDoc(doc);
//break;
}
return true;
}
bool SaveHtml(const std::vector<DocInfo_t> &results,const std::string &output)
{
#define SEP '\3'
//按照二进制方式进行写入
std::ofstream out(output,std::ios::out | std::ios::binary);
if(!out.is_open()){
//std::cerr<<"open "<<output <<"failed!"<<std::endl;
LOG2(FATAL,"open output failed!");
return false;
}
//就可以进行文件内容的写入了
for(auto &item : results)
{
std::string out_string;
out_string = item.title;
out_string+=SEP;
out_string +=item.content;
out_string +=SEP;
out_string +=item.url;
out_string+='\n';
out.write(out_string.c_str(),out_string.size());
}
out.close();
return true;
}
//strstr 前闭后开
searcher.hpp
#pragma once
#include "index.hpp"
#include "util.hpp"
#include <algorithm>
#include <jsoncpp/json/json.h>
#include "log.hpp"
//#include <vector>
namespace ns_searcher{
struct InvertedElemPrint{
uint64_t doc_id;
int weight;
std::vector<std::string> words;
InvertedElemPrint():doc_id(0),weight(0){}
};
class Searcher{
private:
ns_index::Index *index;
public:
Searcher(){}
~Searcher(){}
public:
void InitSearcher(const std::string &input)
{
//1. 获取或者创建index对象
index = ns_index::Index::GetInstance();
//std::cout <<"获取index单例成功..."<<std::endl;
LOG2(NORMAL,"获取index单例成功...");
//2. 根据index对象建立索引
index->BuildIndex(input);//CutString
//std::cout<<"建立正排和倒排索引成功..."<<std::endl;
LOG2(NORMAL,"建立正排和倒排索引成功...");
}
//query:搜索关键字
//json_string:返回给用户浏览器的搜索结果
void Search(const std::string &query,std::string *json_string)
{
//1. [分词]:对我们的query进行按照searcher的要求进行分词
std::vector<std::string> words;
ns_util::JiebaUtil::CutString(query,&words);
//2. [触发]:就是根据分词的各个“词,进行index查找”,建立index是忽略大小写,所以搜索关键字也需要
//ns_index::InvertedList inverted_list_all;
std::vector<InvertedElemPrint> inverted_list_all;
std::unordered_map<uint64_t,InvertedElemPrint> tokens_map;
for(std::string word : words)
{
boost::to_lower(word);
ns_index::InvertedList *inverted_list = index->GetInvertedList(word);
if(nullptr == inverted_list)
{
continue;
}
//不完美的地方(去重)
//inverted_list_all.insert(inverted_list_all.end(),inverted_list->begin(),inverted_list->end());
for(const auto &elem : *inverted_list)
{
auto &item = tokens_map[elem.doc_id];
//item一定是doc_id相同的print节点
item.doc_id =elem.doc_id;
item.weight += elem.weight;
item.words.push_back(elem.word);
}
}
for(const auto&item : tokens_map){
inverted_list_all.push_back(std::move(item.second));
}
//3. [合并排序]:汇总查找结果,按照相关性(weight)降序排序
/*std::sort(inverted_list_all.begin(),inverted_list_all.end(),\
[](const ns_index::InvertedElem &e1,const ns_index::InvertedElem &e2){
return e1.weight>e2.weight;
}
);
*/
std::sort(inverted_list_all.begin(),inverted_list_all.end(),\
[](const InvertedElemPrint&e1,const InvertedElemPrint& e2){
return e1.weight >e2.weight;
});
//4. [构建]:根据查找出来的结果,构建json串————jsoncpp----通过jsoncpp完成序列化和反序列化
Json::Value root;
for(auto &item : inverted_list_all){
ns_index::DocInfo *doc = index->GetForwardIndex(item.doc_id);
if(nullptr == doc)
{
continue;
}
Json::Value elem;
elem["title"] = doc->title;
elem["desc"] = GetDesc(doc->content,item.words[0]); //content是文档的去标签的结果,但是不是我们想要的,我们要的是一部分
elem["url"] = doc->url;
//foe debug
//elem["id"]= (int)item.doc_id;//doc_id是64位的uint64_t
//elem["weight"] = item.weight;
root.append(elem);
}
//Json::StyledWriter writer;
Json::FastWriter writer;
*json_string = writer.write(root);
}
std::string GetDesc(const std::string &html_content,const std::string &word)
{
//找到word在html_content中的首次出现,然后往前找50个字节(如果没有,从begin开始),往后找100个字节(如果没有,到end就可以),截取出这部分内容
const std::size_t prev_step = 50;
const std::size_t next_step =100;
//1. 找到首次出现
auto iter = std::search(html_content.begin(),html_content.end(),word.begin(),word.end(),[](int x,int y){
return (std::tolower(x)==std::tolower(y));
});
if(iter == html_content.end())
{
return "None1";
}
std::size_t pos = std::distance(html_content.begin(),iter);
/*std::size_t pos = html_content.find(word);
if(pos == std::string::npos){
return "None1";//这种情况是不存在的
}*/
//2. 获取start,end //这里有一个大坑,就是std::size_t是一个无符号数,无符号数相减为正数
std::size_t start = 0;
std::size_t end = html_content.size() - 1;
//如果之前有50个字符,就更新开始位置
if(pos >start+ prev_step) start = pos -prev_step;//换成加法
if(pos + next_step <end) end = pos + next_step;
//3. 截取子串,return
if(start >= end)return "None2";
std::string desc = html_content.substr(start,end-start+1);
std::string result="..." + desc + "...";
return result;
}
};
}
util.hpp
#pragma once
#include <iostream>
#include <string>
#include <fstream>
#include <vector>
#include "boost_1_84_0/boost/algorithm/string.hpp"
#include "../cppjieba/include/cppjieba/Jieba.hpp"
//#include "cppjieba/jieba"
#include "log.hpp"
#include <mutex>
#include <unordered_map>
namespace ns_util{
class FileUtil
{
public:
static bool ReadFile(const std::string &file_path,std::string *out)
{
std::ifstream in(file_path,std::ios::in);
if(!in.is_open())
{
//std::cerr << "open_file" << file_path <<"error" <<std::endl;
LOG2(FATAL,"open_file error");
return false;
}
std::string line;
while(std::getline(in,line)){//如何理解getline读取到文件结束呢??getline到返回值是一个&,while(bool),本质是因为重载了强制类型转换
*out += line;
}
in.close();
return true;
}
};
class StringUtil{
public:
static void Split(const std::string&target,std::vector<std::string>*out,const std::string& sep)
{
//boost split
boost::split(*out,target,boost::is_any_of(sep),boost::token_compress_on);
}
};
const char* const DICT_PATH = "./dict/jieba.dict.utf8";
const char* const HMM_PATH = "./dict/hmm_model.utf8";
const char* const USER_DICT_PATH = "./dict/user.dict.utf8";
const char* const IDF_PATH = "./dict/idf.utf8";
const char* const STOP_WORD_PATH = "./dict/stop_words.utf8";
class JiebaUtil
{
private:
//static cppjieba::Jieba jieba;
cppjieba::Jieba jieba;
std::unordered_map<std::string,bool> stop_words;
private:
JiebaUtil():jieba(DICT_PATH,HMM_PATH,USER_DICT_PATH,IDF_PATH,STOP_WORD_PATH)
{}
JiebaUtil(const JiebaUtil&)=delete;
JiebaUtil& operator=(JiebaUtil const&)=delete;
static JiebaUtil *instance;
public:
static JiebaUtil*get_instance()
{
static std::mutex mtx;
if(nullptr==instance){
mtx.lock();
if(nullptr ==instance){
instance = new JiebaUtil();
instance->InitJiebaUtil();
}
mtx.unlock();
}
return instance;
}
void InitJiebaUtil()
{
std::ifstream in(STOP_WORD_PATH);
if(!in.is_open())
{
LOG2(FATAL,"load stop words fill error");
return ;
}
std::string line;
while(std::getline(in,line))
{
stop_words.insert({line,true});
}
in.close();
}
void CutStringHelper(const std::string &src,std::vector<std::string>*out)
{
jieba.CutForSearch(src,*out);
std::vector<std::string> v(*out);
// //for debug
// for(auto e : v)
// {
// std::cout<<"暂停词测试存储 v:"<<e<<"----"<<std::endl;
// }
for(auto iter=out->begin();iter!=out->end();){
auto it =stop_words.find(*iter);
if(it!=stop_words.end())
{
//说明当前的string是暂停词,需要去掉
iter = out->erase(iter);
}
else
{
iter++;
}
}
if(out->empty())
{
//std::cout<< "out为空"<<std::endl;
*out = v;
}
//debug
// std::cout<< out->empty()<<std::endl;
// for(auto e : *out)
// {
// std::cout<<"暂停词测试out 后:"<<e<<"----"<<std::endl;
// }
}
void CutString_has_stop_words(const std::string &src,std::vector<std::string>*out)
{
jieba.CutForSearch(src,*out);
}
public:
static void CutString(const std::string &src,std::vector<std::string> *out)
{
//debug
//std::cout<< "CutStringHelper" << std::endl;
ns_util::JiebaUtil::get_instance()->CutStringHelper(src,out);
//jieba.CutForSearch(src,*out);
}
static void CutString2(const std::string &src,std::vector<std::string> *out)
{
//debug
//std::cout<< "CutString2()" << std::endl;
ns_util::JiebaUtil::get_instance()->CutString_has_stop_words(src,out);
}
//cppjieba::Jieba JiebaUtil::jieba(DICT_PATH,HMM_PATH,USER_DICT_PATH,IDF_PATH,STOP_WORD_PATH);
};
JiebaUtil *JiebaUtil::instance = nullptr;
//加static是因为这个函数要被外部使用,加了static可以不创建对象就可以使用
}
http_server.cc(用于启动服务器和搜索引擎)
#include "searcher.hpp"
#include "httplib.h"
#include "../http.hpp"
const std::string root_path = "./wwwroot";
const std::string input = "data/raw_html/raw.txt";
ns_searcher::Searcher search;
std::string RequestStr(const HttpRequest &req)
{
std::stringstream ss;
ss << req._method << " " << req._path << " " << req._version << "\r\n";
for (auto it : req._params)
{
ss << it.first << ": " << it.second << "\r\n";
DBG_LOG("RequestStr_params: first:%s ,second:%s", it.first, it.second);
}
for (auto it : req._headers)
{
ss << it.first << ": " << it.second << "\r\n";
DBG_LOG("RequestStr_headers: first:%s ,second:%s", it.first.c_str(), it.second.c_str());
}
ss << "\r\n";
ss << req._body;
return ss.str();
}
void Hello(const HttpRequest &req, HttpResponse *rsp)
{
if (!req.HasParam("word"))
{
rsp->SetContent("必须要有搜索关键字!", "text/plain; charset=utf-8");
return;
}
// rsp.set_content("hello world!你好世界\n","text/plain; charset=utf-8");
const std::string word = req.GetParam("word"); // 获取名为word的参数值
// debug
// std::cout<<"test:"<<word<<std::endl;
// std::cout<<"用户在搜索:"<<word<<std::endl;
//LOG2(NORMAL, "用户搜索的:" + word);
std::string json_string;
search.Search(word, &json_string);
rsp->SetContent(json_string, "application/json");
//rsp->SetContent(RequestStr(req), "text/plain");
}
void Login(const HttpRequest &req, HttpResponse *rsp)
{
rsp->SetContent(RequestStr(req),"text/plain");
}
void PutFile(const HttpRequest &req, HttpResponse *rsp)
{
rsp->SetContent(RequestStr(req), "text/plain");
}
void DelFile(const HttpRequest &req, HttpResponse *rsp)
{
rsp->SetContent(RequestStr(req), "text/plain");
}
int main()
{
search.InitSearcher(input);
HttpServer server(8085);
server.SetThreadCount(3);
server.SetBaseDir(root_path); // 设置静态资源根目录,告诉服务器有静态资源请求到来,需要到哪里去找资源路径
server.Get("/s", Hello);
server.Post("/login", Login);
server.Put("/1234.txt", PutFile);
server.Delete("/1234.txt", DelFile);
server.Listen();
return 0;
}
高并发服务器部分
http.hpp
#include <iostream>
#include <fstream>
#include <string>
#include <vector>
#include <regex>
#include <sys/stat.h>
#include "../server.hpp"
#define DEFALT_TIMEOUT 10
std::unordered_map<int, std::string> _statu_msg = {
{100, "Continue"},
{101, "Switching Protocol"},
{102, "Processing"},
{103, "Early Hints"},
{200, "OK"},
{201, "Created"},
{202, "Accepted"},
{203, "Non-Authoritative Information"},
{204, "No Content"},
{205, "Reset Content"},
{206, "Partial Content"},
{207, "Multi-Status"},
{208, "Already Reported"},
{226, "IM Used"},
{300, "Multiple Choice"},
{301, "Moved Permanently"},
{302, "Found"},
{303, "See Other"},
{304, "Not Modified"},
{305, "Use Proxy"},
{306, "unused"},
{307, "Temporary Redirect"},
{308, "Permanent Redirect"},
{400, "Bad Request"},
{401, "Unauthorized"},
{402, "Payment Required"},
{403, "Forbidden"},
{404, "Not Found"},
{405, "Method Not Allowed"},
{406, "Not Acceptable"},
{407, "Proxy Authentication Required"},
{408, "Request Timeout"},
{409, "Conflict"},
{410, "Gone"},
{411, "Length Required"},
{412, "Precondition Failed"},
{413, "Payload Too Large"},
{414, "URI Too Long"},
{415, "Unsupported Media Type"},
{416, "Range Not Satisfiable"},
{417, "Expectation Failed"},
{418, "I'm a teapot"},
{421, "Misdirected Request"},
{422, "Unprocessable Entity"},
{423, "Locked"},
{424, "Failed Dependency"},
{425, "Too Early"},
{426, "Upgrade Required"},
{428, "Precondition Required"},
{429, "Too Many Requests"},
{431, "Request Header Fields Too Large"},
{451, "Unavailable For Legal Reasons"},
{501, "Not Implemented"},
{502, "Bad Gateway"},
{503, "Service Unavailable"},
{504, "Gateway Timeout"},
{505, "HTTP Version Not Supported"},
{506, "Variant Also Negotiates"},
{507, "Insufficient Storage"},
{508, "Loop Detected"},
{510, "Not Extended"},
{511, "Network Authentication Required"}
};
std::unordered_map<std::string, std::string> _mime_msg = {
{".aac", "audio/aac"},
{".abw", "application/x-abiword"},
{".arc", "application/x-freearc"},
{".avi", "video/x-msvideo"},
{".azw", "application/vnd.amazon.ebook"},
{".bin", "application/octet-stream"},
{".bmp", "image/bmp"},
{".bz", "application/x-bzip"},
{".bz2", "application/x-bzip2"},
{".csh", "application/x-csh"},
{".css", "text/css"},
{".csv", "text/csv"},
{".doc", "application/msword"},
{".docx", "application/vnd.openxmlformats-officedocument.wordprocessingml.document"},
{".eot", "application/vnd.ms-fontobject"},
{".epub", "application/epub+zip"},
{".gif", "image/gif"},
{".htm", "text/html"},
{".html", "text/html"},
{".ico", "image/vnd.microsoft.icon"},
{".ics", "text/calendar"},
{".jar", "application/java-archive"},
{".jpeg", "image/jpeg"},
{".jpg", "image/jpeg"},
{".js", "text/javascript"},
{".json", "application/json"},
{".jsonld", "application/ld+json"},
{".mid", "audio/midi"},
{".midi", "audio/x-midi"},
{".mjs", "text/javascript"},
{".mp3", "audio/mpeg"},
{".mpeg", "video/mpeg"},
{".mpkg", "application/vnd.apple.installer+xml"},
{".odp", "application/vnd.oasis.opendocument.presentation"},
{".ods", "application/vnd.oasis.opendocument.spreadsheet"},
{".odt", "application/vnd.oasis.opendocument.text"},
{".oga", "audio/ogg"},
{".ogv", "video/ogg"},
{".ogx", "application/ogg"},
{".otf", "font/otf"},
{".png", "image/png"},
{".pdf", "application/pdf"},
{".ppt", "application/vnd.ms-powerpoint"},
{".pptx", "application/vnd.openxmlformats-officedocument.presentationml.presentation"},
{".rar", "application/x-rar-compressed"},
{".rtf", "application/rtf"},
{".sh", "application/x-sh"},
{".svg", "image/svg+xml"},
{".swf", "application/x-shockwave-flash"},
{".tar", "application/x-tar"},
{".tif", "image/tiff"},
{".tiff", "image/tiff"},
{".ttf", "font/ttf"},
{".txt", "text/plain"},
{".vsd", "application/vnd.visio"},
{".wav", "audio/wav"},
{".weba", "audio/webm"},
{".webm", "video/webm"},
{".webp", "image/webp"},
{".woff", "font/woff"},
{".woff2", "font/woff2"},
{".xhtml", "application/xhtml+xml"},
{".xls", "application/vnd.ms-excel"},
{".xlsx", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"},
{".xml", "application/xml"},
{".xul", "application/vnd.mozilla.xul+xml"},
{".zip", "application/zip"},
{".3gp", "video/3gpp"},
{".3g2", "video/3gpp2"},
{".7z", "application/x-7z-compressed"}
};
class Util {
public:
//字符串分割函数,将src字符串按照sep字符进行分割,得到的各个字串放到arry中,最终返回字串的数量
static size_t Split(const std::string &src, const std::string &sep, std::vector<std::string> *arry) {
size_t offset = 0;
// 有10个字符,offset是查找的起始位置,范围应该是0~9,offset==10就代表已经越界了
while(offset < src.size()) {
size_t pos = src.find(sep, offset);//在src字符串偏移量offset处,开始向后查找sep字符/字串,返回查找到的位置
if (pos == std::string::npos) {//没有找到特定的字符
//将剩余的部分当作一个字串,放入arry中
if(pos == src.size()) break;
arry->push_back(src.substr(offset));
return arry->size();
}
if (pos == offset) {
offset = pos + sep.size();
continue;//当前字串是一个空的,没有内容
}
arry->push_back(src.substr(offset, pos - offset));
offset = pos + sep.size();
}
return arry->size();
}
//读取文件的所有内容,将读取的内容放到一个Buffer中
static bool ReadFile(const std::string &filename, std::string *buf) {
std::ifstream ifs(filename, std::ios::binary);
if (ifs.is_open() == false) {
printf("OPEN %s FILE FAILED!!", filename.c_str());
return false;
}
size_t fsize = 0;
ifs.seekg(0, ifs.end);//跳转读写位置到末尾
fsize = ifs.tellg(); //获取当前读写位置相对于起始位置的偏移量,从末尾偏移刚好就是文件大小
ifs.seekg(0, ifs.beg);//跳转到起始位置
buf->resize(fsize); //开辟文件大小的空间
ifs.read(&(*buf)[0], fsize);
if (ifs.good() == false) {
printf("READ %s FILE FAILED!!", filename.c_str());
ifs.close();
return false;
}
ifs.close();
return true;
}
//向文件写入数据
static bool WriteFile(const std::string &filename, const std::string &buf) {
std::ofstream ofs(filename, std::ios::binary | std::ios::trunc);
if (ofs.is_open() == false) {
printf("OPEN %s FILE FAILED!!", filename.c_str());
return false;
}
ofs.write(buf.c_str(), buf.size());
if (ofs.good() == false) {
ERR_LOG("WRITE %s FILE FAILED!", filename.c_str());
ofs.close();
return false;
}
ofs.close();
return true;
}
//URL编码,避免URL中资源路径与查询字符串中的特殊字符与HTTP请求中特殊字符产生歧义
//编码格式:将特殊字符的ascii值,转换为两个16进制字符,前缀% C++ -> C%2B%2B
// 不编码的特殊字符: RFC3986文档规定 . - _ ~ 字母,数字属于绝对不编码字符
//RFC3986文档规定,编码格式 %HH
//W3C标准中规定,查询字符串中的空格,需要编码为+, 解码则是+转空格
static std::string UrlEncode(const std::string url, bool convert_space_to_plus) {
std::string res;
for (auto &c : url) {
if (c == '.' || c == '-' || c == '_' || c == '~' || isalnum(c)) {
res += c;
continue;
}
if (c == ' ' && convert_space_to_plus == true) {
res += '+';
continue;
}
//剩下的字符都是需要编码成为 %HH 格式
char tmp[4] = {0};
//snprintf 与 printf比较类似,都是格式化字符串,只不过一个是打印,一个是放到一块空间中
snprintf(tmp, 4, "%%%02X", c);
res += tmp;
}
return res;
}
static char HEXTOI(char c) {
if (c >= '0' && c <= '9') {
return c - '0';
}else if (c >= 'a' && c <= 'z') {
return c - 'a' + 10;
}else if (c >= 'A' && c <= 'Z') {
return c - 'A' + 10;
}
return -1;
}
static std::string UrlDecode(const std::string url, bool convert_plus_to_space) {
//遇到了%,则将紧随其后的2个字符,转换为数字,第一个数字左移4位,然后加上第二个数字 + -> 2b %2b->2 << 4 + 11
std::string res;
for (int i = 0; i < url.size(); i++) {
if (url[i] == '+' && convert_plus_to_space == true) {
res += ' ';
continue;
}
if (url[i] == '%' && (i + 2) < url.size()) {
char v1 = HEXTOI(url[i + 1]);
char v2 = HEXTOI(url[i + 2]);
char v = v1 * 16 + v2;
res += v;
i += 2;
continue;
}
res += url[i];
}
return res;
}
//响应状态码的描述信息获取
static std::string StatuDesc(int statu) {
auto it = _statu_msg.find(statu);
if (it != _statu_msg.end()) {
return it->second;
}
return "Unknow";
}
//根据文件后缀名获取文件mime
static std::string ExtMime(const std::string &filename) {
// a.b.txt 先获取文件扩展名
size_t pos = filename.find_last_of('.');
if (pos == std::string::npos) {
return "application/octet-stream";
}
//根据扩展名,获取mime
std::string ext = filename.substr(pos);
auto it = _mime_msg.find(ext);
if (it == _mime_msg.end()) {
return "application/octet-stream";
}
return it->second;
}
//判断一个文件是否是一个目录
static bool IsDirectory(const std::string &filename) {
struct stat st;
int ret = stat(filename.c_str(), &st);
if (ret < 0) {
return false;
}
return S_ISDIR(st.st_mode);
}
//判断一个文件是否是一个普通文件
static bool IsRegular(const std::string &filename) {
struct stat st;
int ret = stat(filename.c_str(), &st);
if (ret < 0) {
return false;
}
return S_ISREG(st.st_mode);
}
//http请求的资源路径有效性判断
// /index.html --- 前边的/叫做相对根目录 映射的是某个服务器上的子目录
// 想表达的意思就是,客户端只能请求相对根目录中的资源,其他地方的资源都不予理会
// /../login, 这个路径中的..会让路径的查找跑到相对根目录之外,这是不合理的,不安全的
static bool ValidPath(const std::string &path) {
//思想:按照/进行路径分割,根据有多少子目录,计算目录深度,有多少层,深度不能小于0
std::vector<std::string> subdir;
Split(path, "/", &subdir);
int level = 0;
for (auto &dir : subdir) {
if (dir == "..") {
level--; //任意一层走出相对根目录,就认为有问题
if (level < 0) return false;
continue;
}
level++;
}
return true;
}
};
class HttpRequest {
public:
std::string _method; //请求方法
std::string _path; //资源路径
std::string _version; //协议版本
std::string _body; //请求正文
std::smatch _matches; //资源路径的正则提取数据
std::unordered_map<std::string, std::string> _headers; //头部字段
std::unordered_map<std::string, std::string> _params; //查询字符串
public:
HttpRequest():_version("HTTP/1.1") {}
void ReSet() {
_method.clear();
_path.clear();
_version = "HTTP/1.1";
_body.clear();
std::smatch match;
_matches.swap(match);
_headers.clear();
_params.clear();
}
//插入头部字段
void SetHeader(const std::string &key, const std::string &val) {
_headers.insert(std::make_pair(key, val));
}
//判断是否存在指定头部字段
bool HasHeader(const std::string &key) const {
auto it = _headers.find(key);
if (it == _headers.end()) {
return false;
}
return true;
}
//获取指定头部字段的值
std::string GetHeader(const std::string &key) const {
auto it = _headers.find(key);
if (it == _headers.end()) {
return "";
}
return it->second;
}
//插入查询字符串
void SetParam(const std::string &key, const std::string &val) {
_params.insert(std::make_pair(key, val));
}
//判断是否有某个指定的查询字符串
bool HasParam(const std::string &key) const {
auto it = _params.find(key);
if (it == _params.end()) {
return false;
}
return true;
}
//获取指定的查询字符串
std::string GetParam(const std::string &key) const {
auto it = _params.find(key);
if (it == _params.end()) {
return "";
}
return it->second;
}
//获取正文长度
size_t ContentLength() const {
// Content-Length: 1234\r\n
bool ret = HasHeader("Content-Length");
if (ret == false) {
return 0;
}
std::string clen = GetHeader("Content-Length");
return std::stol(clen);
}
//判断是否是短链接
bool Close() const {
// 没有Connection字段,或者有Connection但是值是close,则都是短链接,否则就是长连接
if (HasHeader("Connection") == true && GetHeader("Connection") == "keep-alive") {
return false;
}
return true;
}
};
class HttpResponse {
public:
int _statu;
bool _redirect_flag;
std::string _body;
std::string _redirect_url;
std::unordered_map<std::string, std::string> _headers;
public:
HttpResponse():_redirect_flag(false), _statu(200) {}
HttpResponse(int statu):_redirect_flag(false), _statu(statu) {}
void ReSet() {
_statu = 200;
_redirect_flag = false;
_body.clear();
_redirect_url.clear();
_headers.clear();
}
//插入头部字段
void SetHeader(const std::string &key, const std::string &val) {
_headers.insert(std::make_pair(key, val));
}
//判断是否存在指定头部字段
bool HasHeader(const std::string &key) {
auto it = _headers.find(key);
if (it == _headers.end()) {
return false;
}
return true;
}
//获取指定头部字段的值
std::string GetHeader(const std::string &key) {
auto it = _headers.find(key);
if (it == _headers.end()) {
return "";
}
return it->second;
}
void SetContent(const std::string &body, const std::string &type = "text/html") {
_body = body;
SetHeader("Content-Type", type);
}
void SetRedirect(const std::string &url, int statu = 302) {
_statu = statu;
_redirect_flag = true;
_redirect_url = url;
}
//判断是否是短链接
bool Close() {
// 没有Connection字段,或者有Connection但是值是close,则都是短链接,否则就是长连接
if (HasHeader("Connection") == true && GetHeader("Connection") == "keep-alive") {
return false;
}
return true;
}
};
typedef enum {
RECV_HTTP_ERROR,
RECV_HTTP_LINE,
RECV_HTTP_HEAD,
RECV_HTTP_BODY,
RECV_HTTP_OVER
}HttpRecvStatu;
#define MAX_LINE 8192
class HttpContext {
private:
int _resp_statu; //响应状态码
HttpRecvStatu _recv_statu; //当前接收及解析的阶段状态
HttpRequest _request; //已经解析得到的请求信息
private:
bool ParseHttpLine(const std::string &line) {
std::smatch matches;
std::regex e("(GET|HEAD|POST|PUT|DELETE) ([^?]*)(?:\\?(.*))? (HTTP/1\\.[01])(?:\n|\r\n)?", std::regex::icase);
bool ret = std::regex_match(line, matches, e);
if (ret == false) {
_recv_statu = RECV_HTTP_ERROR;
_resp_statu = 400;//BAD REQUEST
return false;
}
//0 : GET /bitejiuyeke/login?user=xiaoming&pass=123123 HTTP/1.1
//1 : GET
//2 : /bitejiuyeke/login
//3 : user=xiaoming&pass=123123
//4 : HTTP/1.1
//请求方法的获取
_request._method = matches[1];
std::transform(_request._method.begin(), _request._method.end(), _request._method.begin(), ::toupper);
//资源路径的获取,需要进行URL解码操作,但是不需要+转空格
_request._path = Util::UrlDecode(matches[2], false);
//协议版本的获取
_request._version = matches[4];
//查询字符串的获取与处理
std::vector<std::string> query_string_arry;
std::string query_string = matches[3];
//查询字符串的格式 key=val&key=val....., 先以 & 符号进行分割,得到各个字串
Util::Split(query_string, "&", &query_string_arry);
//针对各个字串,以 = 符号进行分割,得到key 和val, 得到之后也需要进行URL解码
for (auto &str : query_string_arry) {
size_t pos = str.find("=");
if (pos == std::string::npos) {
_recv_statu = RECV_HTTP_ERROR;
_resp_statu = 400;//BAD REQUEST
return false;
}
std::string key = Util::UrlDecode(str.substr(0, pos), true);
std::string val = Util::UrlDecode(str.substr(pos + 1), true);
_request.SetParam(key, val);
}
return true;
}
bool RecvHttpLine(Buffer *buf) {
if (_recv_statu != RECV_HTTP_LINE) return false;
//1. 获取一行数据,带有末尾的换行
std::string line = buf->GetLineAndPop();
//2. 需要考虑的一些要素:缓冲区中的数据不足一行, 获取的一行数据超大
if (line.size() == 0) {
//缓冲区中的数据不足一行,则需要判断缓冲区的可读数据长度,如果很长了都不足一行,这是有问题的
if (buf->ReadAbleSize() > MAX_LINE) {
_recv_statu = RECV_HTTP_ERROR;
_resp_statu = 414;//URI TOO LONG
return false;
}
//缓冲区中数据不足一行,但是也不多,就等等新数据的到来
return true;
}
if (line.size() > MAX_LINE) {
_recv_statu = RECV_HTTP_ERROR;
_resp_statu = 414;//URI TOO LONG
return false;
}
bool ret = ParseHttpLine(line);
if (ret == false) {
return false;
}
//首行处理完毕,进入头部获取阶段
_recv_statu = RECV_HTTP_HEAD;
return true;
}
bool RecvHttpHead(Buffer *buf) {
if (_recv_statu != RECV_HTTP_HEAD) return false;
//一行一行取出数据,直到遇到空行为止, 头部的格式 key: val\r\nkey: val\r\n....
while(1){
std::string line = buf->GetLineAndPop();
//2. 需要考虑的一些要素:缓冲区中的数据不足一行, 获取的一行数据超大
if (line.size() == 0) {
//缓冲区中的数据不足一行,则需要判断缓冲区的可读数据长度,如果很长了都不足一行,这是有问题的
if (buf->ReadAbleSize() > MAX_LINE) {
_recv_statu = RECV_HTTP_ERROR;
_resp_statu = 414;//URI TOO LONG
return false;
}
//缓冲区中数据不足一行,但是也不多,就等等新数据的到来
return true;
}
if (line.size() > MAX_LINE) {
_recv_statu = RECV_HTTP_ERROR;
_resp_statu = 414;//URI TOO LONG
return false;
}
if (line == "\n" || line == "\r\n") {
break;
}
bool ret = ParseHttpHead(line);
if (ret == false) {
return false;
}
}
//头部处理完毕,进入正文获取阶段
_recv_statu = RECV_HTTP_BODY;
return true;
}
bool ParseHttpHead(std::string &line) {
//key: val\r\nkey: val\r\n....
if (line.back() == '\n') line.pop_back();//末尾是换行则去掉换行字符
if (line.back() == '\r') line.pop_back();//末尾是回车则去掉回车字符
size_t pos = line.find(": ");
if (pos == std::string::npos) {
_recv_statu = RECV_HTTP_ERROR;
_resp_statu = 400;//
return false;
}
std::string key = line.substr(0, pos);
std::string val = line.substr(pos + 2);
_request.SetHeader(key, val);
return true;
}
bool RecvHttpBody(Buffer *buf) {
if (_recv_statu != RECV_HTTP_BODY) return false;
//1. 获取正文长度
size_t content_length = _request.ContentLength();
if (content_length == 0) {
//没有正文,则请求接收解析完毕
_recv_statu = RECV_HTTP_OVER;
return true;
}
//2. 当前已经接收了多少正文,其实就是往 _request._body 中放了多少数据了
size_t real_len = content_length - _request._body.size();//实际还需要接收的正文长度
//3. 接收正文放到body中,但是也要考虑当前缓冲区中的数据,是否是全部的正文
// 3.1 缓冲区中数据,包含了当前请求的所有正文,则取出所需的数据
if (buf->ReadAbleSize() >= real_len) {
_request._body.append(buf->ReadPosition(), real_len);
buf->MoveReadOffset(real_len);
_recv_statu = RECV_HTTP_OVER;
return true;
}
// 3.2 缓冲区中数据,无法满足当前正文的需要,数据不足,取出数据,然后等待新数据到来
_request._body.append(buf->ReadPosition(), buf->ReadAbleSize());
buf->MoveReadOffset(buf->ReadAbleSize());
return true;
}
public:
HttpContext():_resp_statu(200), _recv_statu(RECV_HTTP_LINE) {}
void ReSet() {
_resp_statu = 200;
_recv_statu = RECV_HTTP_LINE;
_request.ReSet();
}
int RespStatu() { return _resp_statu; }
HttpRecvStatu RecvStatu() { return _recv_statu; }
HttpRequest &Request() { return _request; }
//接收并解析HTTP请求
void RecvHttpRequest(Buffer *buf) {
//不同的状态,做不同的事情,但是这里不要break, 因为处理完请求行后,应该立即处理头部,而不是退出等新数据
switch(_recv_statu) {
case RECV_HTTP_LINE: RecvHttpLine(buf);
case RECV_HTTP_HEAD: RecvHttpHead(buf);
case RECV_HTTP_BODY: RecvHttpBody(buf);
}
return;
}
};
class HttpServer {
private:
using Handler = std::function<void(const HttpRequest &, HttpResponse *)>;
using Handlers = std::vector<std::pair<std::regex, Handler>>;
Handlers _get_route;
Handlers _post_route;
Handlers _put_route;
Handlers _delete_route;
std::string _basedir; //静态资源根目录
TcpServer _server;
private:
void ErrorHandler(const HttpRequest &req, HttpResponse *rsp) {
//1. 组织一个错误展示页面
std::string body;
body += "<html>";
body += "<head>";
body += "<meta http-equiv='Content-Type' content='text/html;charset=utf-8'>";
body += "</head>";
body += "<body>";
body += "<h1>";
body += std::to_string(rsp->_statu);
body += " ";
body += Util::StatuDesc(rsp->_statu);
body += "</h1>";
body += "</body>";
body += "</html>";
//2. 将页面数据,当作响应正文,放入rsp中
rsp->SetContent(body, "text/html");
}
//将HttpResponse中的要素按照http协议格式进行组织,发送
void WriteReponse(const PtrConnection &conn, const HttpRequest &req, HttpResponse &rsp) {
//1. 先完善头部字段
if (req.Close() == true) {
rsp.SetHeader("Connection", "close");
}else {
rsp.SetHeader("Connection", "keep-alive");
}
if (rsp._body.empty() == false && rsp.HasHeader("Content-Length") == false) {
rsp.SetHeader("Content-Length", std::to_string(rsp._body.size()));
}
if (rsp._body.empty() == false && rsp.HasHeader("Content-Type") == false) {
rsp.SetHeader("Content-Type", "application/octet-stream");
}
if (rsp._redirect_flag == true) {
rsp.SetHeader("Location", rsp._redirect_url);
}
//2. 将rsp中的要素,按照http协议格式进行组织
std::stringstream rsp_str;
rsp_str << req._version << " " << std::to_string(rsp._statu) << " " << Util::StatuDesc(rsp._statu) << "\r\n";
for (auto &head : rsp._headers) {
rsp_str << head.first << ": " << head.second << "\r\n";
}
rsp_str << "\r\n";
rsp_str << rsp._body;
//3. 发送数据
conn->Send(rsp_str.str().c_str(), rsp_str.str().size());
}
bool IsFileHandler(const HttpRequest &req) {
// 1. 必须设置了静态资源根目录
if (_basedir.empty()) {
return false;
}
// 2. 请求方法,必须是GET / HEAD请求方法
if (req._method != "GET" && req._method != "HEAD") {
return false;
}
// 3. 请求的资源路径必须是一个合法路径
if (Util::ValidPath(req._path) == false) {
return false;
}
// 4. 请求的资源必须存在,且是一个普通文件
// 有一种请求比较特殊 -- 目录:/, /image/, 这种情况给后边默认追加一个 index.html
// index.html /image/a.png
// 不要忘了前缀的相对根目录,也就是将请求路径转换为实际存在的路径 /image/a.png -> ./wwwroot/image/a.png
std::string req_path = _basedir + req._path;//为了避免直接修改请求的资源路径,因此定义一个临时对象
if (req._path.back() == '/') {
req_path += "index.html";
}
if (Util::IsRegular(req_path) == false) {
return false;
}
return true;
}
//静态资源的请求处理 --- 将静态资源文件的数据读取出来,放到rsp的_body中, 并设置mime
void FileHandler(const HttpRequest &req, HttpResponse *rsp) {
std::string req_path = _basedir + req._path;
if (req._path.back() == '/') {
req_path += "index.html";
}
bool ret = Util::ReadFile(req_path, &rsp->_body);
if (ret == false) {
return;
}
std::string mime = Util::ExtMime(req_path);
rsp->SetHeader("Content-Type", mime);
return;
}
//功能性请求的分类处理
void Dispatcher(HttpRequest &req, HttpResponse *rsp, Handlers &handlers) {
//在对应请求方法的路由表中,查找是否含有对应资源请求的处理函数,有则调用,没有则发挥404
//思想:路由表存储的时键值对 -- 正则表达式 & 处理函数
//使用正则表达式,对请求的资源路径进行正则匹配,匹配成功就使用对应函数进行处理
// /numbers/(\d+) /numbers/12345
for (auto &handler : handlers) {
const std::regex &re = handler.first;
const Handler &functor = handler.second;
bool ret = std::regex_match(req._path, req._matches, re);
if (ret == false) {
continue;
}
return functor(req, rsp);//传入请求信息,和空的rsp,执行处理函数
}
rsp->_statu = 404;
}
void Route(HttpRequest &req, HttpResponse *rsp) {
//1. 对请求进行分辨,是一个静态资源请求,还是一个功能性请求
// 静态资源请求,则进行静态资源的处理
// 功能性请求,则需要通过几个请求路由表来确定是否有处理函数
// 既不是静态资源请求,也没有设置对应的功能性请求处理函数,就返回405
if (IsFileHandler(req) == true) {
//是一个静态资源请求, 则进行静态资源请求的处理
return FileHandler(req, rsp);
}
if (req._method == "GET" || req._method == "HEAD") {
return Dispatcher(req, rsp, _get_route);
}else if (req._method == "POST") {
return Dispatcher(req, rsp, _post_route);
}else if (req._method == "PUT") {
return Dispatcher(req, rsp, _put_route);
}else if (req._method == "DELETE") {
return Dispatcher(req, rsp, _delete_route);
}
rsp->_statu = 405;// Method Not Allowed
return ;
}
//设置上下文
void OnConnected(const PtrConnection &conn) {
conn->SetContext(HttpContext());
DBG_LOG("NEW CONNECTION %p", conn.get());
}
//缓冲区数据解析+处理
void OnMessage(const PtrConnection &conn, Buffer *buffer) {
while(buffer->ReadAbleSize() > 0){
//1. 获取上下文
HttpContext *context = conn->GetContext()->get<HttpContext>();
//2. 通过上下文对缓冲区数据进行解析,得到HttpRequest对象
// 1. 如果缓冲区的数据解析出错,就直接回复出错响应
// 2. 如果解析正常,且请求已经获取完毕,才开始去进行处理
context->RecvHttpRequest(buffer);
HttpRequest &req = context->Request();
HttpResponse rsp(context->RespStatu());
if (context->RespStatu() >= 400) {
//进行错误响应,关闭连接
ErrorHandler(req, &rsp);//填充一个错误显示页面数据到rsp中
WriteReponse(conn, req, rsp);//组织响应发送给客户端
context->ReSet();
buffer->MoveReadOffset(buffer->ReadAbleSize());//出错了就把缓冲区数据清空
conn->Shutdown();//关闭连接
return;
}
if (context->RecvStatu() != RECV_HTTP_OVER) {
//当前请求还没有接收完整,则退出,等新数据到来再重新继续处理
return;
}
//3. 请求路由 + 业务处理
Route(req, &rsp);
//4. 对HttpResponse进行组织发送
WriteReponse(conn, req, rsp);
//5. 重置上下文
context->ReSet();
//6. 根据长短连接判断是否关闭连接或者继续处理
if (rsp.Close() == true) conn->Shutdown();//短链接则直接关闭
}
return;
}
public:
HttpServer(int port, int timeout = DEFALT_TIMEOUT):_server(port) {
_server.EnableInactiveRelease(timeout);
_server.SetConnectedCallback(std::bind(&HttpServer::OnConnected, this, std::placeholders::_1));
_server.SetMessageCallback(std::bind(&HttpServer::OnMessage, this, std::placeholders::_1, std::placeholders::_2));
}
void SetBaseDir(const std::string &path) {
assert(Util::IsDirectory(path) == true);
_basedir = path;
}
/*设置/添加,请求(请求的正则表达)与处理函数的映射关系*/
void Get(const std::string &pattern, const Handler &handler) {
_get_route.push_back(std::make_pair(std::regex(pattern), handler));
}
void Post(const std::string &pattern, const Handler &handler) {
_post_route.push_back(std::make_pair(std::regex(pattern), handler));
}
void Put(const std::string &pattern, const Handler &handler) {
_put_route.push_back(std::make_pair(std::regex(pattern), handler));
}
void Delete(const std::string &pattern, const Handler &handler) {
_delete_route.push_back(std::make_pair(std::regex(pattern), handler));
}
void SetThreadCount(int count) {
_server.SetThreadCount(count);
}
void Listen() {
_server.Start();
}
};
server.hpp
#ifndef __M_SERVER_H__
#define __M_SERVER_H__
#include <iostream>
#include <vector>
#include <string>
#include <cassert>
#include <cstring>
#include <ctime>
#include <functional>
#include <unordered_map>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <memory>
#include <typeinfo>
#include <fcntl.h>
#include <signal.h>
#include <unistd.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <sys/timerfd.h>
#define INF 0
#define DBG 1
#define ERR 2
#define LOG_LEVEL DBG
#define LOG(level, format, ...) do{\
if (level < LOG_LEVEL) break;\
time_t t = time(NULL);\
struct tm *ltm = localtime(&t);\
char tmp[32] = {0};\
strftime(tmp, 31, "%H:%M:%S", ltm);\
fprintf(stdout, "[%p %s %s:%d] " format "\n", (void*)pthread_self(), tmp, __FILE__, __LINE__, ##__VA_ARGS__);\
}while(0)
#define INF_LOG(format, ...) LOG(INF, format, ##__VA_ARGS__)
#define DBG_LOG(format, ...) LOG(DBG, format, ##__VA_ARGS__)
#define ERR_LOG(format, ...) LOG(ERR, format, ##__VA_ARGS__)
#define BUFFER_DEFAULT_SIZE 1024
class Buffer {
private:
std::vector<char> _buffer; //使用vector进行内存空间管理
uint64_t _reader_idx; //读偏移
uint64_t _writer_idx; //写偏移
public:
Buffer():_reader_idx(0), _writer_idx(0), _buffer(BUFFER_DEFAULT_SIZE){}
char *Begin() { return &*_buffer.begin(); }
//获取当前写入起始地址, _buffer的空间起始地址,加上写偏移量
char *WritePosition() { return Begin() + _writer_idx; }
//获取当前读取起始地址
char *ReadPosition() { return Begin() + _reader_idx; }
//获取缓冲区末尾空闲空间大小--写偏移之后的空闲空间, 总体空间大小减去写偏移
uint64_t TailIdleSize() { return _buffer.size() - _writer_idx; }
//获取缓冲区起始空闲空间大小--读偏移之前的空闲空间
uint64_t HeadIdleSize() { return _reader_idx; }
//获取可读数据大小 = 写偏移 - 读偏移
uint64_t ReadAbleSize() { return _writer_idx - _reader_idx; }
//将读偏移向后移动
void MoveReadOffset(uint64_t len) {
if (len == 0) return;
//向后移动的大小,必须小于可读数据大小
assert(len <= ReadAbleSize());
_reader_idx += len;
}
//将写偏移向后移动
void MoveWriteOffset(uint64_t len) {
//向后移动的大小,必须小于当前后边的空闲空间大小
assert(len <= TailIdleSize());
_writer_idx += len;
}
//确保可写空间足够(整体空闲空间够了就移动数据,否则就扩容)
void EnsureWriteSpace(uint64_t len) {
//如果末尾空闲空间大小足够,直接返回
if (TailIdleSize() >= len) { return; }
//末尾空闲空间不够,则判断加上起始位置的空闲空间大小是否足够, 够了就将数据移动到起始位置
if (len <= TailIdleSize() + HeadIdleSize()) {
//将数据移动到起始位置
uint64_t rsz = ReadAbleSize();//把当前数据大小先保存起来
std::copy(ReadPosition(), ReadPosition() + rsz, Begin());//把可读数据拷贝到起始位置
_reader_idx = 0; //将读偏移归0
_writer_idx = rsz; //将写位置置为可读数据大小, 因为当前的可读数据大小就是写偏移量
}else {
//总体空间不够,则需要扩容,不移动数据,直接给写偏移之后扩容足够空间即可
DBG_LOG("RESIZE %ld", _writer_idx + len);
_buffer.resize(_writer_idx + len);
}
}
//写入数据
void Write(const void *data, uint64_t len) {
//1. 保证有足够空间,2. 拷贝数据进去
if (len == 0) return;
EnsureWriteSpace(len);
const char *d = (const char *)data;
std::copy(d, d + len, WritePosition());
}
void WriteAndPush(const void *data, uint64_t len) {
Write(data, len);
MoveWriteOffset(len);
}
void WriteString(const std::string &data) {
return Write(data.c_str(), data.size());
}
void WriteStringAndPush(const std::string &data) {
WriteString(data);
MoveWriteOffset(data.size());
}
void WriteBuffer(Buffer &data) {
return Write(data.ReadPosition(), data.ReadAbleSize());
}
void WriteBufferAndPush(Buffer &data) {
WriteBuffer(data);
MoveWriteOffset(data.ReadAbleSize());
}
//读取数据
void Read(void *buf, uint64_t len) {
//要求要获取的数据大小必须小于可读数据大小
assert(len <= ReadAbleSize());
std::copy(ReadPosition(), ReadPosition() + len, (char*)buf);
}
void ReadAndPop(void *buf, uint64_t len) {
Read(buf, len);
MoveReadOffset(len);
}
std::string ReadAsString(uint64_t len) {
//要求要获取的数据大小必须小于可读数据大小
assert(len <= ReadAbleSize());
std::string str;
str.resize(len);
Read(&str[0], len);
return str;
}
std::string ReadAsStringAndPop(uint64_t len) {
assert(len <= ReadAbleSize());
std::string str = ReadAsString(len);
MoveReadOffset(len);
return str;
}
char *FindCRLF() {
char *res = (char*)memchr(ReadPosition(), '\n', ReadAbleSize());
return res;
}
/*通常获取一行数据,这种情况针对是*/
std::string GetLine() {
char *pos = FindCRLF();
if (pos == NULL) {
return "";
}
// +1是为了把换行字符也取出来。
return ReadAsString(pos - ReadPosition() + 1);
}
std::string GetLineAndPop() {
std::string str = GetLine();
MoveReadOffset(str.size());
return str;
}
//清空缓冲区
void Clear() {
//只需要将偏移量归0即可
_reader_idx = 0;
_writer_idx = 0;
}
};
#define MAX_LISTEN 1024
class Socket {
private:
int _sockfd;
public:
Socket():_sockfd(-1) {}
Socket(int fd): _sockfd(fd) {}
~Socket() { Close(); }
int Fd() { return _sockfd; }
//创建套接字
bool Create() {
// int socket(int domain, int type, int protocol)
_sockfd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
if (_sockfd < 0) {
ERR_LOG("CREATE SOCKET FAILED!!");
return false;
}
return true;
}
//绑定地址信息
bool Bind(const std::string &ip, uint16_t port) {
struct sockaddr_in addr;
addr.sin_family = AF_INET;
addr.sin_port = htons(port);
addr.sin_addr.s_addr = inet_addr(ip.c_str());
socklen_t len = sizeof(struct sockaddr_in);
// int bind(int sockfd, struct sockaddr*addr, socklen_t len);
int ret = bind(_sockfd, (struct sockaddr*)&addr, len);
if (ret < 0) {
ERR_LOG("BIND ADDRESS FAILED!");
return false;
}
return true;
}
//开始监听
bool Listen(int backlog = MAX_LISTEN) {
// int listen(int backlog)
int ret = listen(_sockfd, backlog);
if (ret < 0) {
ERR_LOG("SOCKET LISTEN FAILED!");
return false;
}
return true;
}
//向服务器发起连接
bool Connect(const std::string &ip, uint16_t port) {
struct sockaddr_in addr;
addr.sin_family = AF_INET;
addr.sin_port = htons(port);
addr.sin_addr.s_addr = inet_addr(ip.c_str());
socklen_t len = sizeof(struct sockaddr_in);
// int connect(int sockfd, struct sockaddr*addr, socklen_t len);
int ret = connect(_sockfd, (struct sockaddr*)&addr, len);
if (ret < 0) {
ERR_LOG("CONNECT SERVER FAILED!");
return false;
}
return true;
}
//获取新连接
int Accept() {
// int accept(int sockfd, struct sockaddr *addr, socklen_t *len);
int newfd = accept(_sockfd, NULL, NULL);
if (newfd < 0) {
ERR_LOG("SOCKET ACCEPT FAILED!");
return -1;
}
return newfd;
}
//接收数据
ssize_t Recv(void *buf, size_t len, int flag = 0) {
// ssize_t recv(int sockfd, void *buf, size_t len, int flag);
ssize_t ret = recv(_sockfd, buf, len, flag);
if (ret <= 0) {
//EAGAIN 当前socket的接收缓冲区中没有数据了,在非阻塞的情况下才会有这个错误
//EINTR 表示当前socket的阻塞等待,被信号打断了,
if (errno == EAGAIN || errno == EINTR) {
return 0;//表示这次接收没有接收到数据
}
ERR_LOG("SOCKET RECV FAILED!!");
return -1;
}
return ret; //实际接收的数据长度
}
ssize_t NonBlockRecv(void *buf, size_t len) {
return Recv(buf, len, MSG_DONTWAIT); // MSG_DONTWAIT 表示当前接收为非阻塞。
}
//发送数据
ssize_t Send(const void *buf, size_t len, int flag = 0) {
// ssize_t send(int sockfd, void *data, size_t len, int flag);
ssize_t ret = send(_sockfd, buf, len, flag);
if (ret < 0) {
if (errno == EAGAIN || errno == EINTR) {
return 0;
}
ERR_LOG("SOCKET SEND FAILED!!");
return -1;
}
return ret;//实际发送的数据长度
}
ssize_t NonBlockSend(void *buf, size_t len) {
if (len == 0) return 0;
return Send(buf, len, MSG_DONTWAIT); // MSG_DONTWAIT 表示当前发送为非阻塞。
}
//关闭套接字
void Close() {
if (_sockfd != -1) {
close(_sockfd);
_sockfd = -1;
}
}
//创建一个服务端连接
bool CreateServer(uint16_t port, const std::string &ip = "0.0.0.0", bool block_flag = false) {
//1. 创建套接字,2. 绑定地址,3. 开始监听,4. 设置非阻塞, 5. 启动地址重用
if (Create() == false) return false;
if (block_flag) NonBlock();
if (Bind(ip, port) == false) return false;
if (Listen() == false) return false;
ReuseAddress();
return true;
}
//创建一个客户端连接
bool CreateClient(uint16_t port, const std::string &ip) {
//1. 创建套接字,2.指向连接服务器
if (Create() == false) return false;
if (Connect(ip, port) == false) return false;
return true;
}
//设置套接字选项---开启地址端口重用
void ReuseAddress() {
// int setsockopt(int fd, int leve, int optname, void *val, int vallen)
int val = 1;
setsockopt(_sockfd, SOL_SOCKET, SO_REUSEADDR, (void*)&val, sizeof(int));
val = 1;
setsockopt(_sockfd, SOL_SOCKET, SO_REUSEPORT, (void*)&val, sizeof(int));
}
//设置套接字阻塞属性-- 设置为非阻塞
void NonBlock() {
//int fcntl(int fd, int cmd, ... /* arg */ );
int flag = fcntl(_sockfd, F_GETFL, 0);
fcntl(_sockfd, F_SETFL, flag | O_NONBLOCK);
}
};
class Poller;
class EventLoop;
class Channel {
private:
int _fd;
EventLoop *_loop;
uint32_t _events; // 当前需要监控的事件
uint32_t _revents; // 当前连接触发的事件
using EventCallback = std::function<void()>;
EventCallback _read_callback; //可读事件被触发的回调函数
EventCallback _write_callback; //可写事件被触发的回调函数
EventCallback _error_callback; //错误事件被触发的回调函数
EventCallback _close_callback; //连接断开事件被触发的回调函数
EventCallback _event_callback; //任意事件被触发的回调函数
public:
Channel(EventLoop *loop, int fd):_fd(fd), _events(0), _revents(0), _loop(loop) {}
int Fd() { return _fd; }
uint32_t Events() { return _events; }//获取想要监控的事件
void SetREvents(uint32_t events) { _revents = events; }//设置实际就绪的事件
void SetReadCallback(const EventCallback &cb) { _read_callback = cb; }
void SetWriteCallback(const EventCallback &cb) { _write_callback = cb; }
void SetErrorCallback(const EventCallback &cb) { _error_callback = cb; }
void SetCloseCallback(const EventCallback &cb) { _close_callback = cb; }
void SetEventCallback(const EventCallback &cb) { _event_callback = cb; }
//当前是否监控了可读
bool ReadAble() { return (_events & EPOLLIN); }
//当前是否监控了可写
bool WriteAble() { return (_events & EPOLLOUT); }
//启动读事件监控
void EnableRead() { _events |= EPOLLIN; Update(); }
//启动写事件监控
void EnableWrite() { _events |= EPOLLOUT; Update(); }
//关闭读事件监控
void DisableRead() { _events &= ~EPOLLIN; Update(); }
//关闭写事件监控
void DisableWrite() { _events &= ~EPOLLOUT; Update(); }
//关闭所有事件监控
void DisableAll() { _events = 0; Update(); }
//移除监控
void Remove();
void Update();
//事件处理,一旦连接触发了事件,就调用这个函数,自己触发了什么事件如何处理自己决定
void HandleEvent() {
if ((_revents & EPOLLIN) || (_revents & EPOLLRDHUP) || (_revents & EPOLLPRI)) {
/*不管任何事件,都调用的回调函数*/
if (_read_callback) _read_callback();
}
/*有可能会释放连接的操作事件,一次只处理一个*/
if (_revents & EPOLLOUT) {
if (_write_callback) _write_callback();
}else if (_revents & EPOLLERR) {
if (_error_callback) _error_callback();//一旦出错,就会释放连接,因此要放到前边调用任意回调
}else if (_revents & EPOLLHUP) {
if (_close_callback) _close_callback();
}
if (_event_callback) _event_callback();
}
};
#define MAX_EPOLLEVENTS 1024
class Poller {
private:
int _epfd;
struct epoll_event _evs[MAX_EPOLLEVENTS];
std::unordered_map<int, Channel *> _channels;
private:
//对epoll的直接操作
void Update(Channel *channel, int op) {
// int epoll_ctl(int epfd, int op, int fd, struct epoll_event *ev);
int fd = channel->Fd();
struct epoll_event ev;
ev.data.fd = fd;
ev.events = channel->Events();
int ret = epoll_ctl(_epfd, op, fd, &ev);
if (ret < 0) {
ERR_LOG("EPOLLCTL FAILED!");
}
return;
}
//判断一个Channel是否已经添加了事件监控
bool HasChannel(Channel *channel) {
auto it = _channels.find(channel->Fd());
if (it == _channels.end()) {
return false;
}
return true;
}
public:
Poller() {
_epfd = epoll_create(MAX_EPOLLEVENTS);
if (_epfd < 0) {
ERR_LOG("EPOLL CREATE FAILED!!");
abort();//退出程序
}
}
//添加或修改监控事件
void UpdateEvent(Channel *channel) {
bool ret = HasChannel(channel);
if (ret == false) {
//不存在则添加
_channels.insert(std::make_pair(channel->Fd(), channel));
return Update(channel, EPOLL_CTL_ADD);
}
return Update(channel, EPOLL_CTL_MOD);
}
//移除监控
void RemoveEvent(Channel *channel) {
auto it = _channels.find(channel->Fd());
if (it != _channels.end()) {
_channels.erase(it);
}
Update(channel, EPOLL_CTL_DEL);
}
//开始监控,返回活跃连接
void Poll(std::vector<Channel*> *active) {
// int epoll_wait(int epfd, struct epoll_event *evs, int maxevents, int timeout)
int nfds = epoll_wait(_epfd, _evs, MAX_EPOLLEVENTS, -1);
if (nfds < 0) {
if (errno == EINTR) {
return ;
}
ERR_LOG("EPOLL WAIT ERROR:%s\n", strerror(errno));
abort();//退出程序
}
for (int i = 0; i < nfds; i++) {
auto it = _channels.find(_evs[i].data.fd);
assert(it != _channels.end());
it->second->SetREvents(_evs[i].events);//设置实际就绪的事件
active->push_back(it->second);
}
return;
}
};
using TaskFunc = std::function<void()>;
using ReleaseFunc = std::function<void()>;
class TimerTask{
private:
uint64_t _id; // 定时器任务对象ID
uint32_t _timeout; //定时任务的超时时间
bool _canceled; // false-表示没有被取消, true-表示被取消
TaskFunc _task_cb; //定时器对象要执行的定时任务
ReleaseFunc _release; //用于删除TimerWheel中保存的定时器对象信息
public:
TimerTask(uint64_t id, uint32_t delay, const TaskFunc &cb):
_id(id), _timeout(delay), _task_cb(cb), _canceled(false) {}
~TimerTask() {
if (_canceled == false) _task_cb();
_release();
}
void Cancel() { _canceled = true; }
void SetRelease(const ReleaseFunc &cb) { _release = cb; }
uint32_t DelayTime() { return _timeout; }
};
class TimerWheel {
private:
using WeakTask = std::weak_ptr<TimerTask>;
using PtrTask = std::shared_ptr<TimerTask>;
int _tick; //当前的秒针,走到哪里释放哪里,释放哪里,就相当于执行哪里的任务
int _capacity; //表盘最大数量---其实就是最大延迟时间
std::vector<std::vector<PtrTask>> _wheel;
std::unordered_map<uint64_t, WeakTask> _timers;
EventLoop *_loop;
int _timerfd;//定时器描述符--可读事件回调就是读取计数器,执行定时任务
std::unique_ptr<Channel> _timer_channel;
private:
void RemoveTimer(uint64_t id) {
auto it = _timers.find(id);
if (it != _timers.end()) {
_timers.erase(it);
}
}
static int CreateTimerfd() {
int timerfd = timerfd_create(CLOCK_MONOTONIC, 0);
if (timerfd < 0) {
ERR_LOG("TIMERFD CREATE FAILED!");
abort();
}
//int timerfd_settime(int fd, int flags, struct itimerspec *new, struct itimerspec *old);
struct itimerspec itime;
itime.it_value.tv_sec = 1;
itime.it_value.tv_nsec = 0;//第一次超时时间为1s后
itime.it_interval.tv_sec = 1;
itime.it_interval.tv_nsec = 0; //第一次超时后,每次超时的间隔时
timerfd_settime(timerfd, 0, &itime, NULL);
return timerfd;
}
int ReadTimefd() {
uint64_t times;
//有可能因为其他描述符的事件处理花费事件比较长,然后在处理定时器描述符事件的时候,有可能就已经超时了很多次
//read读取到的数据times就是从上一次read之后超时的次数
int ret = read(_timerfd, ×, 8);
if (ret < 0) {
ERR_LOG("READ TIMEFD FAILED!");
abort();
}
return times;
}
//这个函数应该每秒钟被执行一次,相当于秒针向后走了一步
void RunTimerTask() {
_tick = (_tick + 1) % _capacity;
_wheel[_tick].clear();//清空指定位置的数组,就会把数组中保存的所有管理定时器对象的shared_ptr释放掉
}
void OnTime() {
//根据实际超时的次数,执行对应的超时任务
int times = ReadTimefd();
for (int i = 0; i < times; i++) {
RunTimerTask();
}
}
void TimerAddInLoop(uint64_t id, uint32_t delay, const TaskFunc &cb) {
PtrTask pt(new TimerTask(id, delay, cb));
pt->SetRelease(std::bind(&TimerWheel::RemoveTimer, this, id));
int pos = (_tick + delay) % _capacity;
_wheel[pos].push_back(pt);
_timers[id] = WeakTask(pt);
}
void TimerRefreshInLoop(uint64_t id) {
//通过保存的定时器对象的weak_ptr构造一个shared_ptr出来,添加到轮子中
auto it = _timers.find(id);
if (it == _timers.end()) {
return;//没找着定时任务,没法刷新,没法延迟
}
PtrTask pt = it->second.lock();//lock获取weak_ptr管理的对象对应的shared_ptr
int delay = pt->DelayTime();
int pos = (_tick + delay) % _capacity;
_wheel[pos].push_back(pt);
}
void TimerCancelInLoop(uint64_t id) {
auto it = _timers.find(id);
if (it == _timers.end()) {
return;//没找着定时任务,没法刷新,没法延迟
}
PtrTask pt = it->second.lock();
if (pt) pt->Cancel();
}
public:
TimerWheel(EventLoop *loop):_capacity(60), _tick(0), _wheel(_capacity), _loop(loop),
_timerfd(CreateTimerfd()), _timer_channel(new Channel(_loop, _timerfd)) {
_timer_channel->SetReadCallback(std::bind(&TimerWheel::OnTime, this));
_timer_channel->EnableRead();//启动读事件监控
}
/*定时器中有个_timers成员,定时器信息的操作有可能在多线程中进行,因此需要考虑线程安全问题*/
/*如果不想加锁,那就把对定期的所有操作,都放到一个线程中进行*/
void TimerAdd(uint64_t id, uint32_t delay, const TaskFunc &cb);
//刷新/延迟定时任务
void TimerRefresh(uint64_t id);
void TimerCancel(uint64_t id);
/*这个接口存在线程安全问题--这个接口实际上不能被外界使用者调用,只能在模块内,在对应的EventLoop线程内执行*/
bool HasTimer(uint64_t id) {
auto it = _timers.find(id);
if (it == _timers.end()) {
return false;
}
return true;
}
};
class EventLoop {
private:
using Functor = std::function<void()>;
std::thread::id _thread_id;//线程ID
int _event_fd;//eventfd唤醒IO事件监控有可能导致的阻塞
std::unique_ptr<Channel> _event_channel;
Poller _poller;//进行所有描述符的事件监控
std::vector<Functor> _tasks;//任务池
std::mutex _mutex;//实现任务池操作的线程安全
TimerWheel _timer_wheel;//定时器模块
public:
//执行任务池中的所有任务
void RunAllTask() {
std::vector<Functor> functor;
{
std::unique_lock<std::mutex> _lock(_mutex);
_tasks.swap(functor);
}
for (auto &f : functor) {
f();
}
return ;
}
static int CreateEventFd() {
int efd = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK);
if (efd < 0) {
ERR_LOG("CREATE EVENTFD FAILED!!");
abort();//让程序异常退出
}
return efd;
}
void ReadEventfd() {
uint64_t res = 0;
int ret = read(_event_fd, &res, sizeof(res));
if (ret < 0) {
//EINTR -- 被信号打断; EAGAIN -- 表示无数据可读
if (errno == EINTR || errno == EAGAIN) {
return;
}
ERR_LOG("READ EVENTFD FAILED!");
abort();
}
return ;
}
void WeakUpEventFd() {
uint64_t val = 1;
int ret = write(_event_fd, &val, sizeof(val));
if (ret < 0) {
if (errno == EINTR) {
return;
}
ERR_LOG("READ EVENTFD FAILED!");
abort();
}
return ;
}
public:
EventLoop():_thread_id(std::this_thread::get_id()),
_event_fd(CreateEventFd()),
_event_channel(new Channel(this, _event_fd)),
_timer_wheel(this) {
//给eventfd添加可读事件回调函数,读取eventfd事件通知次数
_event_channel->SetReadCallback(std::bind(&EventLoop::ReadEventfd, this));
//启动eventfd的读事件监控
_event_channel->EnableRead();
}
//三步走--事件监控-》就绪事件处理-》执行任务
void Start() {
while(1) {
//1. 事件监控,
std::vector<Channel *> actives;
_poller.Poll(&actives);
//2. 事件处理。
for (auto &channel : actives) {
channel->HandleEvent();
}
//3. 执行任务
RunAllTask();
}
}
//用于判断当前线程是否是EventLoop对应的线程;
bool IsInLoop() {
return (_thread_id == std::this_thread::get_id());
}
void AssertInLoop() {
assert(_thread_id == std::this_thread::get_id());
}
//判断将要执行的任务是否处于当前线程中,如果是则执行,不是则压入队列。
void RunInLoop(const Functor &cb) {
if (IsInLoop()) {
return cb();
}
return QueueInLoop(cb);
}
//将操作压入任务池
void QueueInLoop(const Functor &cb) {
{
std::unique_lock<std::mutex> _lock(_mutex);
_tasks.push_back(cb);
}
//唤醒有可能因为没有事件就绪,而导致的epoll阻塞;
//其实就是给eventfd写入一个数据,eventfd就会触发可读事件
WeakUpEventFd();
}
//添加/修改描述符的事件监控
void UpdateEvent(Channel *channel) { return _poller.UpdateEvent(channel); }
//移除描述符的监控
void RemoveEvent(Channel *channel) { return _poller.RemoveEvent(channel); }
void TimerAdd(uint64_t id, uint32_t delay, const TaskFunc &cb) { return _timer_wheel.TimerAdd(id, delay, cb); }
void TimerRefresh(uint64_t id) { return _timer_wheel.TimerRefresh(id); }
void TimerCancel(uint64_t id) { return _timer_wheel.TimerCancel(id); }
bool HasTimer(uint64_t id) { return _timer_wheel.HasTimer(id); }
};
class LoopThread {
private:
/*用于实现_loop获取的同步关系,避免线程创建了,但是_loop还没有实例化之前去获取_loop*/
std::mutex _mutex; // 互斥锁
std::condition_variable _cond; // 条件变量
EventLoop *_loop; // EventLoop指针变量,这个对象需要在线程内实例化
std::thread _thread; // EventLoop对应的线程
private:
/*实例化 EventLoop 对象,唤醒_cond上有可能阻塞的线程,并且开始运行EventLoop模块的功能*/
void ThreadEntry() {
EventLoop loop;
{
std::unique_lock<std::mutex> lock(_mutex);//加锁
_loop = &loop;
_cond.notify_all();
}
loop.Start();
}
public:
/*创建线程,设定线程入口函数*/
LoopThread():_loop(NULL), _thread(std::thread(&LoopThread::ThreadEntry, this)) {}
/*返回当前线程关联的EventLoop对象指针*/
EventLoop *GetLoop() {
EventLoop *loop = NULL;
{
std::unique_lock<std::mutex> lock(_mutex);//加锁
_cond.wait(lock, [&](){ return _loop != NULL; });//loop为NULL就一直阻塞
loop = _loop;
}
return loop;
}
};
class LoopThreadPool {
private:
int _thread_count;
int _next_idx;
EventLoop *_baseloop;
std::vector<LoopThread*> _threads;
std::vector<EventLoop *> _loops;
public:
LoopThreadPool(EventLoop *baseloop):_thread_count(0), _next_idx(0), _baseloop(baseloop) {}
void SetThreadCount(int count) { _thread_count = count; }
void Create() {
if (_thread_count > 0) {
_threads.resize(_thread_count);
_loops.resize(_thread_count);
for (int i = 0; i < _thread_count; i++) {
_threads[i] = new LoopThread();
_loops[i] = _threads[i]->GetLoop();
}
}
return ;
}
EventLoop *NextLoop() {
if (_thread_count == 0) {
return _baseloop;
}
_next_idx = (_next_idx + 1) % _thread_count;
return _loops[_next_idx];
}
};
class Any{
private:
class holder {
public:
virtual ~holder() {}
virtual const std::type_info& type() = 0;
virtual holder *clone() = 0;
};
template<class T>
class placeholder: public holder {
public:
placeholder(const T &val): _val(val) {}
// 获取子类对象保存的数据类型
virtual const std::type_info& type() { return typeid(T); }
// 针对当前的对象自身,克隆出一个新的子类对象
virtual holder *clone() { return new placeholder(_val); }
public:
T _val;
};
holder *_content;
public:
Any():_content(NULL) {}
template<class T>
Any(const T &val):_content(new placeholder<T>(val)) {}
Any(const Any &other):_content(other._content ? other._content->clone() : NULL) {}
~Any() { delete _content; }
Any &swap(Any &other) {
std::swap(_content, other._content);
return *this;
}
// 返回子类对象保存的数据的指针
template<class T>
T *get() {
//想要获取的数据类型,必须和保存的数据类型一致
assert(typeid(T) == _content->type());
return &((placeholder<T>*)_content)->_val;
}
//赋值运算符的重载函数
template<class T>
Any& operator=(const T &val) {
//为val构造一个临时的通用容器,然后与当前容器自身进行指针交换,临时对象释放的时候,原先保存的数据也就被释放
Any(val).swap(*this);
return *this;
}
Any& operator=(const Any &other) {
Any(other).swap(*this);
return *this;
}
};
class Connection;
//DISCONECTED -- 连接关闭状态; CONNECTING -- 连接建立成功-待处理状态
//CONNECTED -- 连接建立完成,各种设置已完成,可以通信的状态; DISCONNECTING -- 待关闭状态
typedef enum { DISCONNECTED, CONNECTING, CONNECTED, DISCONNECTING}ConnStatu;
using PtrConnection = std::shared_ptr<Connection>;
class Connection : public std::enable_shared_from_this<Connection> {
private:
uint64_t _conn_id; // 连接的唯一ID,便于连接的管理和查找
//uint64_t _timer_id; //定时器ID,必须是唯一的,这块为了简化操作使用conn_id作为定时器ID
int _sockfd; // 连接关联的文件描述符
bool _enable_inactive_release; // 连接是否启动非活跃销毁的判断标志,默认为false
EventLoop *_loop; // 连接所关联的一个EventLoop
ConnStatu _statu; // 连接状态
Socket _socket; // 套接字操作管理
Channel _channel; // 连接的事件管理
Buffer _in_buffer; // 输入缓冲区---存放从socket中读取到的数据
Buffer _out_buffer; // 输出缓冲区---存放要发送给对端的数据
Any _context; // 请求的接收处理上下文
/*这四个回调函数,是让服务器模块来设置的(其实服务器模块的处理回调也是组件使用者设置的)*/
/*换句话说,这几个回调都是组件使用者使用的*/
using ConnectedCallback = std::function<void(const PtrConnection&)>;
using MessageCallback = std::function<void(const PtrConnection&, Buffer *)>;
using ClosedCallback = std::function<void(const PtrConnection&)>;
using AnyEventCallback = std::function<void(const PtrConnection&)>;
ConnectedCallback _connected_callback;
MessageCallback _message_callback;
ClosedCallback _closed_callback;
AnyEventCallback _event_callback;
/*组件内的连接关闭回调--组件内设置的,因为服务器组件内会把所有的连接管理起来,一旦某个连接要关闭*/
/*就应该从管理的地方移除掉自己的信息*/
ClosedCallback _server_closed_callback;
private:
/*五个channel的事件回调函数*/
//描述符可读事件触发后调用的函数,接收socket数据放到接收缓冲区中,然后调用_message_callback
void HandleRead() {
//1. 接收socket的数据,放到缓冲区
char buf[65536];
ssize_t ret = _socket.NonBlockRecv(buf, 65535);
if (ret < 0) {
//出错了,不能直接关闭连接
return ShutdownInLoop();
}
//这里的等于0表示的是没有读取到数据,而并不是连接断开了,连接断开返回的是-1
//将数据放入输入缓冲区,写入之后顺便将写偏移向后移动
_in_buffer.WriteAndPush(buf, ret);
//2. 调用message_callback进行业务处理
if (_in_buffer.ReadAbleSize() > 0) {
//shared_from_this--从当前对象自身获取自身的shared_ptr管理对象
return _message_callback(shared_from_this(), &_in_buffer);
}
}
//描述符可写事件触发后调用的函数,将发送缓冲区中的数据进行发送
void HandleWrite() {
//_out_buffer中保存的数据就是要发送的数据
ssize_t ret = _socket.NonBlockSend(_out_buffer.ReadPosition(), _out_buffer.ReadAbleSize());
if (ret < 0) {
//发送错误就该关闭连接了,
if (_in_buffer.ReadAbleSize() > 0) {
_message_callback(shared_from_this(), &_in_buffer);
}
return Release();//这时候就是实际的关闭释放操作了。
}
_out_buffer.MoveReadOffset(ret);//千万不要忘了,将读偏移向后移动
if (_out_buffer.ReadAbleSize() == 0) {
_channel.DisableWrite();// 没有数据待发送了,关闭写事件监控
//如果当前是连接待关闭状态,则有数据,发送完数据释放连接,没有数据则直接释放
if (_statu == DISCONNECTING) {
return Release();
}
}
return;
}
//描述符触发挂断事件
void HandleClose() {
/*一旦连接挂断了,套接字就什么都干不了了,因此有数据待处理就处理一下,完毕关闭连接*/
if (_in_buffer.ReadAbleSize() > 0) {
_message_callback(shared_from_this(), &_in_buffer);
}
return Release();
}
//描述符触发出错事件
void HandleError() {
return HandleClose();
}
//描述符触发任意事件: 1. 刷新连接的活跃度--延迟定时销毁任务; 2. 调用组件使用者的任意事件回调
void HandleEvent() {
if (_enable_inactive_release == true) { _loop->TimerRefresh(_conn_id); }
if (_event_callback) { _event_callback(shared_from_this()); }
}
//连接获取之后,所处的状态下要进行各种设置(启动读监控,调用回调函数)
void EstablishedInLoop() {
// 1. 修改连接状态; 2. 启动读事件监控; 3. 调用回调函数
assert(_statu == CONNECTING);//当前的状态必须一定是上层的半连接状态
_statu = CONNECTED;//当前函数执行完毕,则连接进入已完成连接状态
// 一旦启动读事件监控就有可能会立即触发读事件,如果这时候启动了非活跃连接销毁
_channel.EnableRead();
if (_connected_callback) _connected_callback(shared_from_this());
}
//这个接口才是实际的释放接口
void ReleaseInLoop() {
//1. 修改连接状态,将其置为DISCONNECTED
_statu = DISCONNECTED;
//2. 移除连接的事件监控
_channel.Remove();
//3. 关闭描述符
_socket.Close();
//4. 如果当前定时器队列中还有定时销毁任务,则取消任务
if (_loop->HasTimer(_conn_id)) CancelInactiveReleaseInLoop();
//5. 调用关闭回调函数,避免先移除服务器管理的连接信息导致Connection被释放,再去处理会出错,因此先调用用户的回调函数
if (_closed_callback) _closed_callback(shared_from_this());
//移除服务器内部管理的连接信息
if (_server_closed_callback) _server_closed_callback(shared_from_this());
}
//这个接口并不是实际的发送接口,而只是把数据放到了发送缓冲区,启动了可写事件监控
void SendInLoop(Buffer &buf) {
if (_statu == DISCONNECTED) return ;
_out_buffer.WriteBufferAndPush(buf);
if (_channel.WriteAble() == false) {
_channel.EnableWrite();
}
}
//这个关闭操作并非实际的连接释放操作,需要判断还有没有数据待处理,待发送
void ShutdownInLoop() {
_statu = DISCONNECTING;// 设置连接为半关闭状态
if (_in_buffer.ReadAbleSize() > 0) {
if (_message_callback) _message_callback(shared_from_this(), &_in_buffer);
}
//要么就是写入数据的时候出错关闭,要么就是没有待发送数据,直接关闭
if (_out_buffer.ReadAbleSize() > 0) {
if (_channel.WriteAble() == false) {
_channel.EnableWrite();
}
}
if (_out_buffer.ReadAbleSize() == 0) {
Release();
}
}
//启动非活跃连接超时释放规则
void EnableInactiveReleaseInLoop(int sec) {
//1. 将判断标志 _enable_inactive_release 置为true
_enable_inactive_release = true;
//2. 如果当前定时销毁任务已经存在,那就刷新延迟一下即可
if (_loop->HasTimer(_conn_id)) {
return _loop->TimerRefresh(_conn_id);
}
//3. 如果不存在定时销毁任务,则新增
_loop->TimerAdd(_conn_id, sec, std::bind(&Connection::Release, this));
}
void CancelInactiveReleaseInLoop() {
_enable_inactive_release = false;
if (_loop->HasTimer(_conn_id)) {
_loop->TimerCancel(_conn_id);
}
}
void UpgradeInLoop(const Any &context,
const ConnectedCallback &conn,
const MessageCallback &msg,
const ClosedCallback &closed,
const AnyEventCallback &event) {
_context = context;
_connected_callback = conn;
_message_callback = msg;
_closed_callback = closed;
_event_callback = event;
}
public:
Connection(EventLoop *loop, uint64_t conn_id, int sockfd):_conn_id(conn_id), _sockfd(sockfd),
_enable_inactive_release(false), _loop(loop), _statu(CONNECTING), _socket(_sockfd),
_channel(loop, _sockfd) {
_channel.SetCloseCallback(std::bind(&Connection::HandleClose, this));
_channel.SetEventCallback(std::bind(&Connection::HandleEvent, this));
_channel.SetReadCallback(std::bind(&Connection::HandleRead, this));
_channel.SetWriteCallback(std::bind(&Connection::HandleWrite, this));
_channel.SetErrorCallback(std::bind(&Connection::HandleError, this));
}
~Connection() { DBG_LOG("RELEASE CONNECTION:%p", this); }
//获取管理的文件描述符
int Fd() { return _sockfd; }
//获取连接ID
int Id() { return _conn_id; }
//是否处于CONNECTED状态
bool Connected() { return (_statu == CONNECTED); }
//设置上下文--连接建立完成时进行调用
void SetContext(const Any &context) { _context = context; }
//获取上下文,返回的是指针
Any *GetContext() { return &_context; }
void SetConnectedCallback(const ConnectedCallback&cb) { _connected_callback = cb; }
void SetMessageCallback(const MessageCallback&cb) { _message_callback = cb; }
void SetClosedCallback(const ClosedCallback&cb) { _closed_callback = cb; }
void SetAnyEventCallback(const AnyEventCallback&cb) { _event_callback = cb; }
void SetSrvClosedCallback(const ClosedCallback&cb) { _server_closed_callback = cb; }
//连接建立就绪后,进行channel回调设置,启动读监控,调用_connected_callback
void Established() {
_loop->RunInLoop(std::bind(&Connection::EstablishedInLoop, this));
}
//发送数据,将数据放到发送缓冲区,启动写事件监控
void Send(const char *data, size_t len) {
//外界传入的data,可能是个临时的空间,我们现在只是把发送操作压入了任务池,有可能并没有被立即执行
//因此有可能执行的时候,data指向的空间有可能已经被释放了。
Buffer buf;
buf.WriteAndPush(data, len);
_loop->RunInLoop(std::bind(&Connection::SendInLoop, this, std::move(buf)));
}
//提供给组件使用者的关闭接口--并不实际关闭,需要判断有没有数据待处理
void Shutdown() {
_loop->RunInLoop(std::bind(&Connection::ShutdownInLoop, this));
}
void Release() {
_loop->QueueInLoop(std::bind(&Connection::ReleaseInLoop, this));
}
//启动非活跃销毁,并定义多长时间无通信就是非活跃,添加定时任务
void EnableInactiveRelease(int sec) {
_loop->RunInLoop(std::bind(&Connection::EnableInactiveReleaseInLoop, this, sec));
}
//取消非活跃销毁
void CancelInactiveRelease() {
_loop->RunInLoop(std::bind(&Connection::CancelInactiveReleaseInLoop, this));
}
//切换协议---重置上下文以及阶段性回调处理函数 -- 而是这个接口必须在EventLoop线程中立即执行
//防备新的事件触发后,处理的时候,切换任务还没有被执行--会导致数据使用原协议处理了。
void Upgrade(const Any &context, const ConnectedCallback &conn, const MessageCallback &msg,
const ClosedCallback &closed, const AnyEventCallback &event) {
_loop->AssertInLoop();
_loop->RunInLoop(std::bind(&Connection::UpgradeInLoop, this, context, conn, msg, closed, event));
}
};
class Acceptor {
private:
Socket _socket;//用于创建监听套接字
EventLoop *_loop; //用于对监听套接字进行事件监控
Channel _channel; //用于对监听套接字进行事件管理
using AcceptCallback = std::function<void(int)>;
AcceptCallback _accept_callback;
private:
/*监听套接字的读事件回调处理函数---获取新连接,调用_accept_callback函数进行新连接处理*/
void HandleRead() {
int newfd = _socket.Accept();
if (newfd < 0) {
return ;
}
if (_accept_callback) _accept_callback(newfd);
}
int CreateServer(int port) {
bool ret = _socket.CreateServer(port);
assert(ret == true);
return _socket.Fd();
}
public:
/*不能将启动读事件监控,放到构造函数中,必须在设置回调函数后,再去启动*/
/*否则有可能造成启动监控后,立即有事件,处理的时候,回调函数还没设置:新连接得不到处理,且资源泄漏*/
Acceptor(EventLoop *loop, int port): _socket(CreateServer(port)), _loop(loop),
_channel(loop, _socket.Fd()) {
_channel.SetReadCallback(std::bind(&Acceptor::HandleRead, this));
}
void SetAcceptCallback(const AcceptCallback &cb) { _accept_callback = cb; }
void Listen() { _channel.EnableRead(); }
};
class TcpServer {
private:
uint64_t _next_id; //这是一个自动增长的连接ID,
int _port;
int _timeout; //这是非活跃连接的统计时间---多长时间无通信就是非活跃连接
bool _enable_inactive_release;//是否启动了非活跃连接超时销毁的判断标志
EventLoop _baseloop; //这是主线程的EventLoop对象,负责监听事件的处理
Acceptor _acceptor; //这是监听套接字的管理对象
LoopThreadPool _pool; //这是从属EventLoop线程池
std::unordered_map<uint64_t, PtrConnection> _conns;//保存管理所有连接对应的shared_ptr对象
using ConnectedCallback = std::function<void(const PtrConnection&)>;
using MessageCallback = std::function<void(const PtrConnection&, Buffer *)>;
using ClosedCallback = std::function<void(const PtrConnection&)>;
using AnyEventCallback = std::function<void(const PtrConnection&)>;
using Functor = std::function<void()>;
ConnectedCallback _connected_callback;
MessageCallback _message_callback;
ClosedCallback _closed_callback;
AnyEventCallback _event_callback;
private:
void RunAfterInLoop(const Functor &task, int delay) {
_next_id++;
_baseloop.TimerAdd(_next_id, delay, task);
}
//为新连接构造一个Connection进行管理
void NewConnection(int fd) {
_next_id++;
PtrConnection conn(new Connection(_pool.NextLoop(), _next_id, fd));
conn->SetMessageCallback(_message_callback);
conn->SetClosedCallback(_closed_callback);
conn->SetConnectedCallback(_connected_callback);
conn->SetAnyEventCallback(_event_callback);
conn->SetSrvClosedCallback(std::bind(&TcpServer::RemoveConnection, this, std::placeholders::_1));
if (_enable_inactive_release) conn->EnableInactiveRelease(_timeout);//启动非活跃超时销毁
conn->Established();//就绪初始化
_conns.insert(std::make_pair(_next_id, conn));
}
void RemoveConnectionInLoop(const PtrConnection &conn) {
int id = conn->Id();
auto it = _conns.find(id);
if (it != _conns.end()) {
_conns.erase(it);
}
}
//从管理Connection的_conns中移除连接信息
void RemoveConnection(const PtrConnection &conn) {
_baseloop.RunInLoop(std::bind(&TcpServer::RemoveConnectionInLoop, this, conn));
}
public:
TcpServer(int port):
_port(port),
_next_id(0),
_enable_inactive_release(false),
_acceptor(&_baseloop, port),
_pool(&_baseloop) {
_acceptor.SetAcceptCallback(std::bind(&TcpServer::NewConnection, this, std::placeholders::_1));
_acceptor.Listen();//将监听套接字挂到baseloop上
}
void SetThreadCount(int count) { return _pool.SetThreadCount(count); }
void SetConnectedCallback(const ConnectedCallback&cb) { _connected_callback = cb; }
void SetMessageCallback(const MessageCallback&cb) { _message_callback = cb; }
void SetClosedCallback(const ClosedCallback&cb) { _closed_callback = cb; }
void SetAnyEventCallback(const AnyEventCallback&cb) { _event_callback = cb; }
void EnableInactiveRelease(int timeout) { _timeout = timeout; _enable_inactive_release = true; }
//用于添加一个定时任务
void RunAfter(const Functor &task, int delay) {
_baseloop.RunInLoop(std::bind(&TcpServer::RunAfterInLoop, this, task, delay));
}
void Start() { _pool.Create(); _baseloop.Start(); }
};
void Channel::Remove() { return _loop->RemoveEvent(this); }
void Channel::Update() { return _loop->UpdateEvent(this); }
void TimerWheel::TimerAdd(uint64_t id, uint32_t delay, const TaskFunc &cb) {
_loop->RunInLoop(std::bind(&TimerWheel::TimerAddInLoop, this, id, delay, cb));
}
//刷新/延迟定时任务
void TimerWheel::TimerRefresh(uint64_t id) {
_loop->RunInLoop(std::bind(&TimerWheel::TimerRefreshInLoop, this, id));
}
void TimerWheel::TimerCancel(uint64_t id) {
_loop->RunInLoop(std::bind(&TimerWheel::TimerCancelInLoop, this, id));
}
class NetWork {
public:
NetWork() {
DBG_LOG("SIGPIPE INIT");
signal(SIGPIPE, SIG_IGN);
}
};
static NetWork nw;
#endif