背景
我们正在使用Rust开发基于RAG的知识库系统,其中对于模型的回复使用了常用的SSE,Web框架使用Rocket,Rocket提供了一个简单的方式支持SSE,但没有会话保持、会话恢复等功能,因此我们自己简单实现这两个功能。
使用Rocket推送消息流
以下,是Rocket给出的示例:
#[get("/text/stream")]
fn stream() -> TextStream![String] {
TextStream! {
let mut interval = time::interval(Duration::from_secs(1));
for i in 0..10 {
yield format!("n: {}", i);
interval.tick().await;
}
}
}
我们需要改造这个示例,以满足将模型回复的消息推送给前端的需求。
首先,对于既然推流,需要知道将流推送给谁,也就是要推送到哪个会话中,所以我们在发起会话的时候,需要一个会话ID来标识一个唯一的会话。
我们使用sse.js这个库作为SSE的客户端,用于发起SSE连接,该库可以通过发起POST请求来建立连接,可以携带额外的数据和请求头。
使用以下结构来接收一个对话请求:
pub struct ChatMessageReq {
/// 会话ID
pub session_id: String,
/// 消息内容
pub content: String,
}
于是我们的接口需要修改为这样:
#[post("/chat", data = "<req>")]
async fn question(
req: Json<ChatMessageReq>) -> (ContentType, TextStream<impl Stream<Item = String>>)
{
//TODO
}
其中TextStream<impl Stream<Item = String>>
等价于TextStream![String]
。
需要注意的是,如果返回值没有指定ContentType
,那么Rocket默认响应的ContentType是文本类型,而非stream类型,会导致前端无法解析。
接下来我们实现会话管理功能。
我们定义了一个名为SsePool的struct,来存储并管理SSE连接:
struct SsePool {
/// 消息流传输通道
channel: DashMap<String, Sender<SseEvent>>,
}
impl SsePool {
/// 初始化连接池
pub fn init() -> Self {
Self {
channel: DashMap::new(),
}
}
/// 移除连接
fn remove(&self, id: &String) {
if let Some((_, sender)) = self.channel.remove(id) {
drop(sender);
}
}
/// 获取连接
fn get_sender(&self, id: &String) -> Option<Sender<SseEvent>> {
self.channel.get(id).map(|v| v.value().clone())
}
/// 新建channel
pub fn new_channel(&self, id: String) -> (Sender<SseEvent>, Receiver<SseEvent>) {
let (sender, receiver) = tokio::sync::mpsc::channel(10_0000);
// 获取并移除旧sender
let old_sender = self.channel.remove(&id).map(|(_, s)| s);
// 插入新sender
self.channel.insert(id, sender.clone());
// 处理旧sender
if let Some(old_sender) = old_sender {
tokio::spawn(async move {
// 发送终止信号
let _ = old_sender.send(SseEvent::Abort).await;
});
}
(sender, receiver)
}
/// 发送消息
pub async fn send_message(&self, id: &String, message: ChatMessage) {
if let Some(sender) = self.get_sender(id) {
if let Err(e) = sender.send(SseEvent::ChatMessage(message)).await {
log::warn!("消息发送失败,session id: {},失败原因:{}", &id, e);
};
// drop(sender);
}
}
}
其中channel使用的是tokio中mpsc的channel。
值得注意的是,new_channel
中,新建连接时,需要向channel发送一条终止事件,确保已有的receiver关闭,返回新的receiver,这一点是用于后续的会话恢复使用。new_channel
会返回receiver和sender,用于消息接收和发送。
当收到模型回复是,调用SsePool::send_message
发送消息到channel,再头通过receiver接收消息,转发到前端。
可以把它初始化到静态变量中,方便全局调用:
static SSE_POOL: LazyLock<SsePool> = LazyLock::new(|| SsePool::init());
于是,我们的接口可以完善为以下内容:
#[post("/chat", data = "<req>")]
async fn question(
req: Json<ChatMessageReq>,
) -> (ContentType, TextStream<impl Stream<Item = String>>)
{
// 请求新消息,并返回receiver
let (_, _, mut receiver) = service::new_message(req).await.unwrap();
let stream = TextStream! {
// 持续接收receiver的消息,然后推送到前端
while let Some(item) = receiver.recv().await {
match item{
//模型回复的消息
SseEvent::ChatMessage(message) => {
yield SseEvent::ChatMessage(message.clone()).to_message();
if SseEvent::is_done(&message) {
// 推送消息
yield SseEvent::Abort.to_message();
break;
}
},
// 关闭通道
SseEvent::Abort => {
yield SseEvent::Abort.to_message();
break;
},
_ => {}
}
}
yield SseEvent::Abort.to_message();
drop(receiver);
};
(ContentType::EventStream, stream)
}
至此,新会话的接口就完成了。
接下来是会话的恢复。
当前端切换会话或刷新页面时,我们希望能够继续收到未回复完成的消息,所以需要一个用于会话恢复的接口。同样的,接口需要会话ID来区分恢复哪一个会话。
#[post("/resume-stream", data = "<req>")]
async fn resume_stream(
req: Json<ResumeStreamReq>,
) -> (ContentType, TextStream<impl Stream<Item = String>>)
{
// 会话ID
let session_id = req.session_id.clone();
let stream = TextStream! {
// 尝试恢复会话,并返回receiver,如果能够返回receiver说明会话未完成,否则已经完成
if let Some(mut receiver) = service::resume_stream(&req.session_id)
.await
.unwrap()
{
// 持续接收未回复完成的消息
while let Some(item) = receiver.recv().await {
match item {
// 模型回复的消息
SseEvent::ChatMessage(message) => {
yield SseEvent::ChatMessage(message.clone()).to_message();
if SseEvent::is_done(&message) {
yield SseEvent::Abort.to_message();
break;
}
}
// 关闭通道
SseEvent::Abort => {
yield SseEvent::Abort.to_message();
break;
}
_ => {}
}
}
drop(receiver);
}
yield SseEvent::Abort.to_message();
};
(ContentType::EventStream, stream)
}
在service::resume_stream
中,首先检查对应会话ID的channel是否存在,存在则新建channel返回receiver,不存在则表明已经回复完成。
pub async fn resume_stream(
session_id: &String,
) -> AppResult<Option<Receiver<SseEvent>>> {
if let None = chat::get_connection(session_id) {
return Ok(None);
}
// 获取会话对应的channel,如果channel存在则标识消息仍在回复中
let (_, receiver) = chat::new_connection(session_id);
Ok(Some(receiver))
}
至此,便实现了会话恢复,刷新页面后仍能后接收strem消息。
总结
使用Rust写这些业务代码的速度,终归是没有Java快,一些常用的库,没有Java系列封装的简单易用,不过应用占用资源确实比Java小很多。
本次使用的一些库:
- tokio:异步运行环境,以及mpsc的channel,
- dashmap:支持并发的hashmap,但是使用不当容易造成死锁。