WebSocket使得客户端和服务器之间的数据交换变得更加简单,允许服务端主动向客户端推送数据。在WebSocket API中,浏览器和服务器只需要完成一次握手,两者之间就可以创建持久性的连接,并进行双向数据传输。
一、为什么需要 WebSocket?
初次接触 WebSocket 的人,都会问同样的问题:我们已经有了 HTTP 协议,为什么还需要另一个协议?它能带来什么好处?
答案很简单,因为 HTTP 协议有一个缺陷:通信只能由客户端发起。
举例来说,我们想了解今天的天气,只能是客户端向服务器发出请求,服务器返回查询结果。HTTP 协议做不到服务器主动向客户端推送信息。
这种单向请求的特点,注定了如果服务器有连续的状态变化,客户端要获知就非常麻烦。我们只能使用"轮询":每隔一段时间,就发出一个询问,了解服务器有没有新的信息。最典型的场景就是聊天室。
轮询的效率低,非常浪费资源(因为必须不停连接,或者 HTTP 连接始终打开)。因此,工程师们一直在思考,有没有更好的方法。WebSocket 就是这样发明的。
二、简介
WebSocket 协议在2008年诞生,2011年成为国际标准。所有浏览器都已经支持了。
它的最大特点就是,服务器可以主动向客户端推送信息,客户端也可以主动向服务器发送信息,是真正的双向平等对话,属于服务器推送技术的一种。
其他特点包括:
(1)建立在 TCP 协议之上,服务器端的实现比较容易。
(2)与 HTTP 协议有着良好的兼容性。默认端口也是80和443,并且握手阶段采用 HTTP 协议,因此握手时不容易屏蔽,能通过各种 HTTP 代理服务器。
(3)数据格式比较轻量,性能开销小,通信高效。
(4)可以发送文本,也可以发送二进制数据。
(5)没有同源限制,客户端可以与任意服务器通信。
(6)协议标识符是ws(如果加密,则为wss),服务器网址就是 URL。
服务端的实现
依赖 spring-boot-starter-websocket 模块实现 WebSocket 实时对话交互。
CustomTextWebSocketHandler,扩展的 TextWebSocketHandler
import cn.hutool.json.JSONUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.PongMessage;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import java.util.concurrent.CountDownLatch;
/**
* 文本处理器
*
* @see org.springframework.web.socket.handler.TextWebSocketHandler
*/
@Slf4j
public class CustomTextWebSocketHandler extends TextWebSocketHandler {
/**
* 第三方身份,消息身份
*/
private String thirdPartyId;
/**
* 回复消息内容
*/
private String replyContent;
private StringBuilder replyContentBuilder;
/**
* 完成信号
*/
private final CountDownLatch doneSignal;
public CustomTextWebSocketHandler(CountDownLatch doneSignal) {
this.doneSignal = doneSignal;
}
public String getThirdPartyId() {
return thirdPartyId;
}
public String getReplyContent() {
return replyContent;
}
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
log.info("connection established, session={}", session);
replyContentBuilder = new StringBuilder(16);
// super.afterConnectionEstablished(session);
}
@Override
public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
super.handleMessage(session, message);
}
/**
* 消息已接收完毕("stop")
*/
private static final String MESSAGE_DONE = "[DONE]";
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
// super.handleTextMessage(session, message);
String payload = message.getPayload();
log.info("payload={}", payload);
OpenAiReplyResponse replyResponse = Jsons.fromJson(payload, OpenAiReplyResponse.class);
if (replyResponse != null && replyResponse.isSuccess()) {
String msg = replyResponse.getMsg();
if (Strings.isEmpty(msg)) {
return;
} else if (msg.startsWith("【超出最大单次回复字数】")) {
// {"msg":"【超出最大单次回复字数】该提示由GPT官方返回,非我司限制,请缩减回复字数","code":1,
// "extParam":"{\"chatId\":\"10056:8889007174\",\"requestId\":\"b6af5830a5a64fa8a4ca9451d7cb5f6f\",\"bizId\":\"\"}",
// "id":"chatcmpl-7LThw6J9KmBUOcwK1SSOvdBP2vK9w"}
return;
} else if (msg.startsWith("发送内容包含敏感词")) {
// {"msg":"发送内容包含敏感词,请修改后重试。不合规汇如下:炸弹","code":1,
// "extParam":"{\"chatId\":\"10024:8889006970\",\"requestId\":\"828068d945c8415d8f32598ef6ef4ad6\",\"bizId\":\"430\"}",
// "id":"4d4106c3-f7d4-4393-8cce-a32766d43f8b"}
matchSensitiveWords = msg;
// 请求完成
doneSignal.countDown();
return;
} else if (MESSAGE_DONE.equals(msg)) {
// 消息已接收完毕
replyContent = replyContentBuilder.toString();
thirdPartyId = replyResponse.getId();
// 请求完成
doneSignal.countDown();
log.info("replyContent={}", replyContent);
return;
}
replyContentBuilder.append(msg);
}
}
@Override
protected void handlePongMessage(WebSocketSession session, PongMessage message) throws Exception {
super.handlePongMessage(session, message);
}
@Override
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
replyContentBuilder = null;
log.info("handle transport error, session={}", session, exception);
doneSignal.countDown();
// super.handleTransportError(session, exception);
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
replyContentBuilder = null;
log.info("connection closed, session={}, status={}", session, status);
if (status == CloseStatus.NORMAL) {
log.error("connection closed fail, session={}, status={}", session, status);
}
doneSignal.countDown();
// super.afterConnectionClosed(session, status);
}
}
OpenAiHandler
/**
 * Lifecycle hooks invoked around an OpenAI request.
 *
 * <p>Both hooks do nothing by default; implementations override only the ones they need.
 *
 * @param <Req> request type
 * @param <Rsp> response type
 */
public interface OpenAiHandler<Req, Rsp> {

    /**
     * Hook invoked before the request is sent.
     *
     * @param req the request about to be sent
     */
    default void beforeRequest(Req req) {
        // intentionally empty: no pre-processing by default
    }

    /**
     * Hook invoked after the response has been obtained.
     *
     * @param req the request that was sent
     * @param rsp the response that was produced
     */
    default void afterResponse(Req req, Rsp rsp) {
        // intentionally empty: no post-processing by default
    }
}
OpenAiService
/**
 * OpenAI service contract.
 * <pre>
 * API reference introduction
 * https://platform.openai.com/docs/api-reference/introduction
 * </pre>
 *
 * @param <Req> request type
 * @param <Rsp> response type
 */
public interface OpenAiService<Req, Rsp> extends OpenAiHandler<Req, Rsp> {

    /**
     * Performs the completion call itself, without the surrounding hooks.
     *
     * @param req the request
     * @return the response
     */
    Rsp doCompletions(Req req);

    /**
     * Runs a completion wrapped by the handler hooks: {@code beforeRequest}, then
     * {@link #doCompletions(Object)}, then {@code afterResponse} with whatever
     * {@code doCompletions} returned.
     *
     * @param req the request
     * @return the value returned by {@link #doCompletions(Object)}
     */
    default Rsp completions(Req req) {
        this.beforeRequest(req);
        final Rsp result = this.doCompletions(req);
        this.afterResponse(req, result);
        return result;
    }
}
OpenAiServiceImpl
import cn.hutool.core.util.StrUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Configuration;
import org.springframework.stereotype.Service;
import org.springframework.util.StopWatch;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.WebSocketClient;
import javax.annotation.Nullable;
import java.io.IOException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
/**
 * {@code OpenAiService} implementation that performs each completion over a dedicated
 * WebSocket session: it opens the session, sends one signed request, blocks until the
 * handler signals completion (or a timeout elapses) and then closes the session.
 */
@Slf4j
@Configuration(proxyBeanMethods = false)
@EnableConfigurationProperties(OpenAiProperties.class)
@Service("openAiService")
public class OpenAiServiceImpl implements OpenAiService<CompletionReq, CompletionRsp> {

    private static final String THREAD_NAME_PREFIX = "gpt.openai";

    private static final String URI_TEMPLATE = "wss://socket.******.com/websocket/{sessionId}";

    /**
     * Smallest per-request token limit that is honoured; smaller values fall back to the
     * configured maximum.
     */
    private static final int MIN_TOKENS = 11;

    private final OpenAiProperties properties;

    /**
     * WebSocket client shared by all requests of this service.
     */
    private final WebSocketClient webSocketClient;

    /**
     * Service used to persist model request records.
     */
    private final ModelRequestRecordService modelRequestRecordService;

    public OpenAiServiceImpl(
            OpenAiProperties properties,
            ModelRequestRecordService modelRequestRecordService
    ) {
        this.properties = properties;
        this.modelRequestRecordService = modelRequestRecordService;
        webSocketClient = WebSocketUtil.applyWebSocketClient(THREAD_NAME_PREFIX);
        log.info("create OpenAiServiceImpl instance");
    }

    /**
     * Assigns a generated request id when the request does not carry one.
     *
     * @param req the request
     */
    @Override
    public void beforeRequest(CompletionReq req) {
        if (Strings.isEmpty(req.getRequestId())) {
            req.setRequestId(UuidUtil.getUuid());
        }
    }

    /**
     * Persists a request record for a response that has reply content and no sensitive-word
     * hit; does nothing otherwise.
     *
     * @param req the request
     * @param rsp the response, may be {@code null}
     */
    @Override
    public void afterResponse(CompletionReq req, CompletionRsp rsp) {
        if (rsp == null || Strings.isEmpty(rsp.getReplyContent())) {
            return;
        }
        if (Strings.isNotEmpty(rsp.getMatchSensitiveWords())) {
            // Sensitive words reported by the third party: already present on the response,
            // and such exchanges are not recorded.
            return;
        }
        // Per-stage timing. Only one task is started: StopWatch.start throws
        // IllegalStateException if a previous task is still running, so any additional stage
        // added here (e.g. a local sensitive-word check) must be stopped before the next start.
        StopWatch stopWatch = new StopWatch(req.getRequestId());
        try {
            stopWatch.start("saveModelRequestRecord");
            ModelRequestRecord entity = applyModelRequestRecord(req, rsp);
            modelRequestRecordService.save(entity);
        } finally {
            if (stopWatch.isRunning()) {
                stopWatch.stop();
            }
            log.info("afterResponse execute time, {}", stopWatch);
        }
    }

    /**
     * Builds the record to persist from the request and its response.
     */
    private static ModelRequestRecord applyModelRequestRecord(
            CompletionReq req, CompletionRsp rsp) {
        Long orgId = req.getOrgId();
        Long userId = req.getUserId();
        String chatId = applyChatId(orgId, userId);
        return new ModelRequestRecord()
                .setOrgId(orgId)
                .setUserId(userId)
                .setModelType(req.getModelType())
                .setRequestId(req.getRequestId())
                .setBizId(req.getBizId())
                .setChatId(chatId)
                .setThirdPartyId(rsp.getThirdPartyId())
                .setInputMessage(req.getMessage())
                .setReplyContent(rsp.getReplyContent());
    }

    /**
     * Builds the chat id as {@code orgId:userId}.
     */
    private static String applyChatId(Long orgId, Long userId) {
        return orgId + ":" + userId;
    }

    /**
     * Builds the session id as {@code appId_chatId}.
     */
    private static String applySessionId(String appId, String chatId) {
        return appId + '_' + chatId;
    }

    /**
     * Sends the request over a new WebSocket session and waits for the streamed reply.
     *
     * @param req the request
     * @return the response, or {@code null} when neither a reply nor a sensitive-word
     *         message was received, or when connecting, sending or waiting failed
     */
    @Nullable
    @Override
    public CompletionRsp doCompletions(CompletionReq req) {
        // Per-stage timing
        StopWatch stopWatch = new StopWatch(req.getRequestId());
        stopWatch.start("doHandshake");
        // Latch released by the handler once the exchange is finished
        CountDownLatch doneSignal = new CountDownLatch(1);
        CustomTextWebSocketHandler webSocketHandler = new CustomTextWebSocketHandler(doneSignal);
        String chatId = applyChatId(req.getOrgId(), req.getUserId());
        String sessionId = applySessionId(properties.getAppId(), chatId);
        ListenableFuture<WebSocketSession> listenableFuture = webSocketClient
                .doHandshake(webSocketHandler, URI_TEMPLATE, sessionId);
        stopWatch.stop();
        stopWatch.start("getWebSocketSession");
        long connectionTimeout = properties.getConnectionTimeout().getSeconds();
        try (WebSocketSession webSocketSession = listenableFuture.get(connectionTimeout, TimeUnit.SECONDS)) {
            stopWatch.stop();
            stopWatch.start("sendMessage");
            OpenAiParam param = applyParam(chatId, req);
            webSocketSession.sendMessage(new TextMessage(Jsons.toJson(param)));
            long requestTimeout = properties.getRequestTimeout().getSeconds();
            // Wait for the handler to signal completion
            boolean await = doneSignal.await(requestTimeout, TimeUnit.SECONDS);
            if (!await) {
                log.error("await doneSignal fail, req={}", req);
            }
            String replyContent = webSocketHandler.getReplyContent();
            String matchSensitiveWords = webSocketHandler.getMatchSensitiveWords();
            if (Strings.isEmpty(replyContent) && Strings.isEmpty(matchSensitiveWords)) {
                // No usable reply
                return null;
            }
            // Strip the first and the last occurrence of the delimiter from the reply
            String delimiters = properties.getDelimiters();
            replyContent = StrUtil.replaceFirst(replyContent, delimiters, "");
            replyContent = StrUtil.replaceLast(replyContent, delimiters, "");
            String thirdPartyId = webSocketHandler.getThirdPartyId();
            return new CompletionRsp()
                    .setThirdPartyId(thirdPartyId)
                    .setReplyContent(replyContent)
                    .setMatchSensitiveWords(matchSensitiveWords);
        } catch (InterruptedException e) {
            // Restore the interrupt flag for the caller; abandon a handshake that may still
            // be pending (cancel is a no-op once the future has completed).
            Thread.currentThread().interrupt();
            listenableFuture.cancel(true);
            log.error("get WebSocketSession fail, req={}", req, e);
        } catch (TimeoutException e) {
            // Only the handshake wait throws this; cancel it so a late connection does not
            // leave an unowned open session behind.
            listenableFuture.cancel(true);
            log.error("get WebSocketSession fail, req={}", req, e);
        } catch (ExecutionException e) {
            log.error("get WebSocketSession fail, req={}", req, e);
        } catch (IOException e) {
            log.error("sendMessage fail, req={}", req, e);
        } finally {
            if (stopWatch.isRunning()) {
                stopWatch.stop();
            }
            log.info("doCompletions execute time, {}", stopWatch);
        }
        return null;
    }

    /**
     * Resolves the token limit for a single reply.
     *
     * @param reqMaxTokens    limit requested by the caller
     * @param maxTokensConfig configured maximum
     * @return {@code maxTokensConfig} when the requested limit is below {@code MIN_TOKENS}
     *         or above the configured maximum; otherwise {@code reqMaxTokens}
     */
    private static int applyMaxTokens(int reqMaxTokens, int maxTokensConfig) {
        if (reqMaxTokens < MIN_TOKENS || maxTokensConfig < reqMaxTokens) {
            return maxTokensConfig;
        }
        return reqMaxTokens;
    }

    /**
     * Builds the signed request payload: the prompt followed by the user message wrapped in
     * the configured delimiter, plus context, token limit and tracing parameters.
     */
    private OpenAiParam applyParam(String chatId, CompletionReq req) {
        OpenAiDataExtParam extParam = new OpenAiDataExtParam()
                .setChatId(chatId)
                .setRequestId(req.getRequestId())
                .setBizId(req.getBizId());
        String prompt = req.getPrompt();
        String delimiters = properties.getDelimiters();
        String message = prompt + delimiters + req.getMessage() + delimiters;
        int maxTokens = applyMaxTokens(req.getMaxTokens(), properties.getMaxTokens());
        OpenAiData data = new OpenAiData()
                .setMsg(message)
                .setContext(properties.getContext())
                .setLimitTokens(maxTokens)
                .setExtParam(extParam);
        String sign = OpenAiUtil.applySign(message, properties.getSecret());
        return new OpenAiParam()
                .setData(data)
                .setSign(sign);
    }
}
WebSocketUtil
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
/**
 * WebSocket helper methods.
 */
public final class WebSocketUtil {

    /**
     * Thread name prefix used when the caller does not supply one.
     */
    private static final String DEFAULT_THREAD_NAME_PREFIX = "gpt.web.socket";

    /**
     * Upper bound of the executor's thread pool.
     */
    private static final int MAX_POOL_SIZE = 200;

    private WebSocketUtil() {
        // static helpers only
    }

    /**
     * Creates a new WebSocket client backed by its own daemon thread pool.
     *
     * @param threadNamePrefix prefix for the pool's thread names; when it has no text,
     *                         {@code "gpt.web.socket"} is used
     * @return a {@code StandardWebSocketClient} with an initialized task executor
     */
    public static WebSocketClient applyWebSocketClient(String threadNamePrefix) {
        int cpuNum = Runtime.getRuntime().availableProcessors();
        ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
        // The core size must not exceed the maximum, otherwise the underlying pool cannot
        // be created on hosts with more processors than MAX_POOL_SIZE.
        taskExecutor.setCorePoolSize(Math.min(cpuNum, MAX_POOL_SIZE));
        taskExecutor.setMaxPoolSize(MAX_POOL_SIZE);
        taskExecutor.setDaemon(true);
        taskExecutor.setThreadNamePrefix(
                StringUtils.hasText(threadNamePrefix) ? threadNamePrefix : DEFAULT_THREAD_NAME_PREFIX);
        taskExecutor.initialize();

        StandardWebSocketClient webSocketClient = new StandardWebSocketClient();
        webSocketClient.setTaskExecutor(taskExecutor);
        return webSocketClient;
    }
}
OpenAiUtil
import org.springframework.util.DigestUtils;
import java.nio.charset.StandardCharsets;
/**
 * OpenAI helper methods.
 */
public final class OpenAiUtil {

    private OpenAiUtil() {
        // static helpers only
    }

    /**
     * Computes the request signature: the MD5 digest of the message concatenated with the
     * secret, UTF-8 encoded. This is a one-way hash, not encryption.
     *
     * @param message message content
     * @param secret  signing secret appended to the message before hashing
     * @return the digest as a hexadecimal string
     */
    public static String applySign(String message, String secret) {
        String data = message + secret;
        byte[] dataBytes = data.getBytes(StandardCharsets.UTF_8);
        return DigestUtils.md5DigestAsHex(dataBytes);
    }
}
参考资料
- WebSocket - 维基百科
- WebSocket 教程 - 阮一峰
- 使用WebSocket - 廖雪峰
- WebSocket Support - Spring Framework
- Messaging WebSockets - Spring Boot
- Create WebSocket Endpoints Using @ServerEndpoint - “How-to” Guides - Spring Boot