文章目录
- 1. 什么是 Short Write 问题?
- 2. 如何解决 Short Write 问题?
- 2.1 方法 1:将 Socket 设置为阻塞模式
- 2.2 方法 2:用户态维护发送缓冲区
- 3. 用户态维护发送缓冲区实现
- 3.1 核心要点
- 3.2 代码实现
- 3.3 测试程序
- 参考文档
1. 什么是 Short Write 问题?
在 TCP 编程中,short write
问题指的是在调用 send
或 write
系统调用时,实际发送的数据量比预期要少。这通常是因为网络协议栈发送缓冲区的空间不足,导致不能一次性发送完整的数据。遇到这种情况时,系统调用会返回实际发送的字节数,并将 errno
设置为 EAGAIN
,表示缓冲区没有足够的空间来继续发送数据。
2. 如何解决 Short Write 问题?
针对 EPOLL 模型 中的 LT(Level Triggered)模式,可以采取以下几种方案来解决 short write
问题:
2.1 方法 1:将 Socket 设置为阻塞模式
将 Socket 设置为阻塞模式时,send
系统调用会一直阻塞,直到有足够的缓冲区空间发送完整的数据。这种方法能够避免 short write
问题,但会导致线程阻塞,从而影响性能。因此,通常不推荐在高并发或需要高吞吐量的场景中使用此方法。
2.2 方法 2:用户态维护发送缓冲区
更推荐的方法是用户态维护一个发送缓冲区,并结合 EPOLLONESHOT
和 EPOLLOUT
事件来控制数据发送。这种方法不需要阻塞线程,能够有效地处理 short write
问题。
3. 用户态维护发送缓冲区实现
3.1 核心要点
使用环形缓冲区来保存待发送的数据,当系统发送缓冲区不够时,数据会被存入环形缓冲区,并在后续的 EPOLLOUT
事件触发时继续发送。
- 环形缓冲区设计
环形缓冲区(Circular Buffer)是一个固定大小的缓存,用于暂存数据。当数据无法完全写入网络协议栈缓冲区时,可以将其暂存,并在缓冲区有足够空间时继续写入。通过注册 EPOLLOUT
事件,当缓冲区有空闲空间时,程序可以重新尝试发送数据。
-
数据收发管理类设计
-
asyncSend
:当数据网络协议栈发送缓冲区没有足够空间时,会将数据存储到环形缓冲区,并通过EPOLLONESHOT
和EPOLLOUT
事件确保数据能在后续时刻继续发送。 -
doSend
:此函数会被EPOLLOUT
事件触发,它从环形缓冲区中取出数据并尝试发送。如果发送成功,则释放相应的缓冲区空间;如果发送失败,且错误码为EAGAIN
或EINTER
,则会重试。
-
3.2 代码实现
#include <atomic>
#include <cstring>
#include <memory>
#include <sys/socket.h>
#include <sys/epoll.h>
#include <unistd.h>
#include <cstdio>
#include <mutex>
#include <cassert>
#include <fcntl.h>
class LockFreeBytesBuffer {
public:
static const std::size_t kBufferSize = 10240U; // 缓冲区大小
LockFreeBytesBuffer() noexcept : reader_index_(0U), writer_index_(0U) {
std::memset(buffer_, 0, kBufferSize);
}
bool append(const char* data, std::size_t length) noexcept;
std::size_t beginRead(const char** target) noexcept;
void endRead(std::size_t length) noexcept;
private:
char buffer_[kBufferSize];
std::atomic<std::size_t> reader_index_;
std::atomic<std::size_t> writer_index_;
};
bool LockFreeBytesBuffer::append(const char* data, std::size_t length) noexcept {
const std::size_t current_write_index = writer_index_.load(std::memory_order_relaxed);
const std::size_t current_read_index = reader_index_.load(std::memory_order_acquire);
const std::size_t free_space = (current_read_index + kBufferSize - current_write_index - 1U) % kBufferSize;
if (length > free_space) {
return false; // 缓冲区满
}
const std::size_t pos = current_write_index % kBufferSize;
const std::size_t first_part = std::min(length, kBufferSize - pos);
std::memcpy(&buffer_[pos], data, first_part);
std::memcpy(&buffer_[0], data + first_part, length - first_part);
writer_index_.store(current_write_index + length, std::memory_order_release);
return true;
}
std::size_t LockFreeBytesBuffer::beginRead(const char** target) noexcept {
const std::size_t current_read_index = reader_index_.load(std::memory_order_relaxed);
const std::size_t current_write_index = writer_index_.load(std::memory_order_acquire);
const std::size_t available_data = (current_write_index - current_read_index) % kBufferSize;
if (available_data == 0U) {
return 0U; // 缓冲区空
}
const std::size_t pos = current_read_index % kBufferSize;
*target = &buffer_[pos];
return std::min(available_data, kBufferSize - pos);
}
void LockFreeBytesBuffer::endRead(std::size_t length) noexcept {
const std::size_t current_read_index = reader_index_.load(std::memory_order_relaxed);
reader_index_.store(current_read_index + length, std::memory_order_release);
}
class SocketContext {
public:
SocketContext(int epoll_fd, int sock_fd)
: epoll_fd_(epoll_fd), sock_fd_(sock_fd) {
setNonblocking();
addFd();
}
~SocketContext() {
removeFd();
close(sock_fd_);
}
bool asyncSend(const char* data, int size) {
bool result = buffer_.append(data, static_cast<std::size_t>(size));
if (result) {
modifyEvent(false, true); // 修改 EPOLLONESHOT 和 EPOLLOUT
}
return result;
}
int doRecv() {
char buffer[1024] = {};
int count = read(sock_fd_, buffer, sizeof(buffer));
if (count <= 0) {
return count; // 读取失败或连接关闭
}
modifyEvent(true, false); // 恢复 EPOLLIN 事件
fprintf(stderr, "Received: %s\n", buffer);
return count;
}
int doSend() {
const char* pdata = nullptr;
std::size_t data_size = buffer_.beginRead(&pdata);
if (data_size == 0) {
return 0; // 没有数据可以发送
}
int send_size = send(sock_fd_, pdata, static_cast<int>(data_size), MSG_DONTWAIT);
if (send_size > 0) {
buffer_.endRead(static_cast<std::size_t>(send_size)); // 更新已发送数据
} else if (send_size == -1 && errno != EAGAIN) {
fprintf(stderr, "send failed, error: %s\n", strerror(errno));
}
return send_size;
}
protected:
void setNonblocking() {
int flags = fcntl(sock_fd_, F_GETFL, 0);
if (flags == -1) {
fprintf(stderr, "fcntl GETFL failed: %s\n", strerror(errno));
return;
}
fcntl(sock_fd_, F_SETFL, flags | O_NONBLOCK);
}
void addFd() {
struct epoll_event event;
event.data.ptr = this;
event.events = EPOLLIN | EPOLLONESHOT | EPOLLOUT;
if (epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, sock_fd_, &event) == -1) {
fprintf(stderr, "epoll_ctl add failed: %s\n", strerror(errno));
}
}
void removeFd() {
epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, sock_fd_, nullptr);
}
inline void modifyEvent(bool in_event = true, bool out_event = true) {
struct epoll_event event;
event.data.ptr = this;
event.events = EPOLLONESHOT;
if (in_event) {
event.events |= EPOLLIN;
}
if (out_event) {
event.events |= EPOLLOUT;
}
epoll_ctl(epoll_fd_, EPOLL_CTL_MOD, sock_fd_, &event);
}
private:
int epoll_fd_;
int sock_fd_;
LockFreeBytesBuffer buffer_;
};
代码说明
- 无锁环形缓冲区:
LockFreeBytesBuffer
类通过原子操作(std::atomic
)来确保线程安全,避免了传统的锁机制。
更多请见:C++生产者-消费者无锁缓冲区的简单实现 - 事件驱动机制:通过
EPOLLIN
和EPOLLOUT
事件来控制数据的接收和发送,避免了阻塞操作。 - 非阻塞发送:通过
send
函数的MSG_DONTWAIT
标志来确保发送操作不会阻塞,遇到EAGAIN
错误时会重试。
3.3 测试程序
#include <iostream>
#include <memory>
#include <cstring>
#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <sys/epoll.h>
#include <fcntl.h>
#include <cassert>
#define MAX_EVENTS 10
int createServerSocket(int port) {
int sockfd = socket(AF_INET, SOCK_STREAM, 0);
if (sockfd == -1) {
fprintf(stderr, "socket creation failed: %s\n", strerror(errno));
return -1;
}
sockaddr_in server_addr;
std::memset(&server_addr, 0, sizeof(server_addr));
server_addr.sin_family = AF_INET;
server_addr.sin_addr.s_addr = INADDR_ANY;
server_addr.sin_port = htons(port);
if (bind(sockfd, (struct sockaddr*)&server_addr, sizeof(server_addr)) == -1) {
fprintf(stderr, "bind failed: %s\n", strerror(errno));
close(sockfd);
return -1;
}
if (listen(sockfd, 5) == -1) {
fprintf(stderr, "listen failed: %s\n", strerror(errno));
close(sockfd);
return -1;
}
return sockfd;
}
int main() {
int epoll_fd = epoll_create1(0);
if (epoll_fd == -1) {
fprintf(stderr, "epoll_create1 failed: %s\n", strerror(errno));
return -1;
}
int server_fd = createServerSocket(8080);
if (server_fd == -1) {
return -1;
}
// Set the server socket to non-blocking mode
fcntl(server_fd, F_SETFL, O_NONBLOCK);
// Add server socket to epoll
struct epoll_event ev;
ev.events = EPOLLIN;
ev.data.fd = server_fd;
if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, server_fd, &ev) == -1) {
fprintf(stderr, "epoll_ctl: server_fd failed: %s\n", strerror(errno));
return -1;
}
fprintf(stderr, "Server listening on port 8080...\n");
while (true) {
struct epoll_event events[MAX_EVENTS];
int n = epoll_wait(epoll_fd, events, MAX_EVENTS, -1);
for (int i = 0; i < n; ++i) {
if (events[i].data.fd == server_fd) {
// Accept new client connection
int client_fd = accept(server_fd, NULL, NULL);
if (client_fd == -1) {
fprintf(stderr, "accept failed: %s\n", strerror(errno));
continue;
}
fcntl(client_fd, F_SETFL, O_NONBLOCK);
std::unique_ptr<SocketContext> client = std::make_unique<SocketContext>(epoll_fd, client_fd);
ev.events = EPOLLIN | EPOLLOUT | EPOLLONESHOT;
ev.data.ptr = client.get();
if (epoll_ctl(epoll_fd, EPOLL_CTL_ADD, client_fd, &ev) == -1) {
fprintf(stderr, "epoll_ctl: client_fd failed: %s\n", strerror(errno));
}
} else {
SocketContext* client = static_cast<SocketContext*>(events[i].data.ptr);
if (events[i].events & EPOLLIN) {
client->doRecv();
}
if (events[i].events & EPOLLOUT) {
client->doSend();
}
}
}
}
close(server_fd);
close(epoll_fd);
return 0;
}
参考文档
- tcp 解决short write问题
- C++生产者-消费者无锁缓冲区的简单实现