之前我们通过阿里提供的cloud ai对接了通义千问。cloud ai对接通义千问
那么接下来我们尝试一些别的模型看一下,其实这个文章主要是表达一种对接方式,其他的都大同小异。都可以依此方法进行处理。
一、明确模型参数
本次我们对接的理论支持来自于阿里云提供的文档。阿里云大3-6b模型文档
我们看到他其实支持多种调用方式,包括sdk和http,我本人是不喜欢sdk的,因为会有冲突或者版本之类的问题,不如直接调用三方,把问题都扔到三方侧。所以我们这里来展示一下使用http的调用方式。
而且大模型的chat一般都是流式的,非流式的没啥技术含量而且效果很low。所以我们直接参考这部分内容即可,
我们看到他们的服务端其实是支持SSE的推流方式的,具体SSE是啥可以自行百度。
而流式和非流式的区别就在于请求参数的设置。如果你配置了,那大模型端就会给你按照流式响应。
在有了以上理论支持之后,我们就来测试一下。
二、代码接入
我们看到他的示例请求参数为:
curl --location 'https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation' \
--header 'Authorization: Bearer <YOUR-DASHSCOPE-API-KEY>' \ # 这里写你的appkey
--header 'Content-Type: application/json' \
--header 'X-DashScope-SSE: enable' \ # 开启流式
--data '{
"model": "chatglm3-6b", # 模型名字
"input":{
"messages":[
{
"role": "user",
"content": "你好,请介绍一下故宫"
}
]
},
"parameters": {
"result_format": "message"
}
}'
所以我们可以找到关键点就在以上三处,至于如何申请appkey,可以参考官方。
那么我们接下来就使用okhttp这种支持事件响应的来对接流式的输出。
1、编写返回内容反序列化类
首先我们先来处理返回格式,我决定用一个java类来接受,具体你觉得不灵活可以直接用Json,怎么弄都行。
我们来看一下官网的响应示例格式。
{"output":{"choices":[{"message":{"content":"\n 故宫是中国北京市中心的一座明清两代的皇宫,现已成为博物馆。故宫是中国最具代表性的古建筑之一,也是世界文化遗产之一,以其丰富的文化遗产和精美的建筑艺术而闻名于世界。故宫占地面积达72万平方米,拥有9000多间房屋和70多座建筑,由大小湖泊、宫殿、花园和殿堂组成,是中国古代宫殿建筑之精华。","role":"assistant"},"finish_reason":"stop"}]},"usage":{"total_tokens":105,"input_tokens":24,"output_tokens":81},"request_id":"9d970376-4ba3-98b8-8387-f95702280341"}
我们看到他是个字符串,然后在流式的最后一句他的finish_reason的值是stop,这时候我们就可以结束推流。
OK,我们就来接收一下。
import lombok.Data;
@Data
public class Chatglm36bResponse {
private Output output;
private Usage usage;
private String requestId;
@Data
public static class Output {
private Choice[] choices;
@Data
public static class Choice {
private Message message;
private String finishReason;
@Data
public static class Message {
private String content;
private String role;
}
}
}
@Data
public static class Usage {
private int totalTokens;
private int inputTokens;
private int outputTokens;
}
}
2、编写event事件监听器
import com.alibaba.fastjson.JSONObject;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import java.io.IOException;
@EqualsAndHashCode(callSuper = true)
@Data
@Slf4j
@NoArgsConstructor
@AllArgsConstructor
public class ChatEventSourceListener extends EventSourceListener {
private String clientId;
@Override
public void onOpen(EventSource eventSource, Response response) {
log.info("ChatEventSourceListener onOpen invoke");
super.onOpen(eventSource, response);
}
@Override
public void onEvent(EventSource eventSource, String id, String type, String data) {
log.info("ChatEventSourceListener onEvent invoke");
Chatglm36bResponse chatglm36bResponse = JSONObject.parseObject(data, Chatglm36bResponse.class);
Chatglm36bResponse.Output output = chatglm36bResponse.getOutput();
Chatglm36bResponse.Output.Choice[] choices = output.getChoices();
for (Chatglm36bResponse.Output.Choice choice : choices) {
String finishReason = choice.getFinishReason();
String content = choice.getMessage().getContent();
log.info("ChatEventSourceListener onEvent finishReason is:{},content is:{}", finishReason, content);
try {
// 给前端推流,前端有组件可以接收这种流。
SseEmitterUtils.sendMsg(clientId, content);
} catch (IOException e) {
throw new RuntimeException(e);
}
// 结束了,取消事件,并且结束SSE推流
if ("stop".equals(finishReason)) {
eventSource.cancel();
SseEmitterUtils.completeDelay(clientId);
}
}
super.onEvent(eventSource, id, type, data);
}
@Override
public void onClosed(EventSource eventSource) {
log.info("ChatEventSourceListener onClosed invoke ******");
super.onClosed(eventSource);
}
@Override
public void onFailure(EventSource eventSource, Throwable t, Response response) {
super.onFailure(eventSource, t, response);
String message = response.message();
response.close();
log.info("ChatEventSourceListener onFailure invoke ****** Throwable is:{},res is {}", t.getMessage(),message);
}
}
我们在每一类事件里面都做了相应的处理。
与之配套的是一个SSE的工具类。
package com.yxy.springbootdemo.utils.sse;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.*;
public class SseEmitterUtils {
private static final Logger logger = LoggerFactory.getLogger(SseEmitterUtils.class);
private static final ThreadPoolExecutor ssePool = new ThreadPoolExecutor(
20,
200,
30,
TimeUnit.SECONDS,
new LinkedBlockingQueue<>(1000),
runnable -> new Thread(runnable, "sse-sendMsg-pool"),
new ThreadPoolExecutor.AbortPolicy()
);
// SSE连接关闭延迟时间
private static final Integer EMITTER_COMPLETE_DELAY_MILLISECONDS = 5000;
// SSE连接初始化超时时间
private static final Long EMITTER_TIME_OUT_MILLISECONDS = 600_000L;
// 缓存 SSE连接
private static final Map<String, SseEmitter> SSE_CACHE = new ConcurrentHashMap<>();
/**
* 获取 SSE连接 默认超时时间EMITTER_TIME_OUT_MILLISECONDS 毫秒
*
* @param clientId 客户端 ID
* @return 连接对象
*/
public static SseEmitter getConnection(String clientId) {
return getConnection(clientId,EMITTER_TIME_OUT_MILLISECONDS);
}
/**
* 获取 SSE连接
*
* @param clientId 客户端 ID
* @param timeout 连接超时时间,单位毫秒
* @return 连接对象
*/
public static SseEmitter getConnection(String clientId,Long timeout) {
final SseEmitter sseEmitter = SSE_CACHE.get(clientId);
if (Objects.nonNull(sseEmitter)) {
return sseEmitter;
} else {
final SseEmitter emitter = new SseEmitter(timeout);
// 初始化emitter回调
initSseEmitter(emitter, clientId);
// 连接建立后,将连接放入缓存
SSE_CACHE.put(clientId, emitter);
logger.info("[SseEmitter] 连接已建立,clientId = {}", clientId);
return emitter;
}
}
/**
* 关闭指定的流连接
*
* @param clientId 客户端 ID
*/
public static void closeConnection(String clientId) {
final SseEmitter sseEmitter = SSE_CACHE.get(clientId);
logger.info("[流式响应-停止生成] 收到客户端关闭连接指令,Emitter is {},clientId = {}", null == sseEmitter ? "NOT-Exist" : "Exist", clientId);
if (Objects.nonNull(sseEmitter)) {
SSE_CACHE.remove(clientId);
sseEmitter.complete();
}
try {
TimeUnit.MILLISECONDS.sleep(EMITTER_COMPLETE_DELAY_MILLISECONDS);
} catch (InterruptedException ex) {
logger.error("流式响应异常", ex);
Thread.currentThread().interrupt();
}
}
/**
* 推送消息
*
* @param clientId 客户端 ID
* @param msg 消息
* @return 连接是否存在
* @throws IOException IO异常
*/
public static boolean sendMsg(String clientId, String msg) throws IOException {
final SseEmitter sseEmitter = SSE_CACHE.get(clientId);
if (Objects.nonNull(sseEmitter)) {
try {
sseEmitter.send(msg);
} catch (Exception e) {
logger.error("[流式响应-停止生成] ");
return true;
}
return false;
} else {
return true;
}
}
/**
* 异步推送消息 TODO 目前未实现提供回调
*
* @param clientId 客户端 ID
* @param msg 消息
* @return 连接是否存在
* @throws IOException IO异常
*/
public static boolean sendMsgAsync(String clientId, String msg){
final SseEmitter sseEmitter = SSE_CACHE.get(clientId);
if (Objects.nonNull(sseEmitter)) {
try {
ssePool.submit(()->{
try {
sseEmitter.send(msg);
} catch (IOException e) {
logger.error("[流式响应-停止生成] ");
}
});
} catch (Exception e) {
logger.error("[流式响应-停止生成] ");
return true;
}
return false;
} else {
return true;
}
}
/**
* 立即关闭SseEmitter,可能存在推流不完全的情况,谨慎使用
*
* @param clientId
*/
public static void complete(String clientId) {
completeDelay(clientId,0);
}
/**
* 延迟关闭 SseEmitter,延迟一定时长时为了尽量保证最后一次推送数据被前端完整接收
*
* @param clientId 客户端ID
*/
public static void completeDelay(String clientId) {
completeDelay(clientId,EMITTER_COMPLETE_DELAY_MILLISECONDS);
}
/**
* 延迟关闭 SseEmitter,延迟指定时长时为了尽量保证最后一次推送数据被前端完整接收
*
* @param clientId 客户端ID
*/
public static void completeDelay(String clientId,Integer delayMilliSeconds) {
final SseEmitter sseEmitter = SSE_CACHE.get(clientId);
if (Objects.nonNull(sseEmitter)) {
try {
TimeUnit.MILLISECONDS.sleep(delayMilliSeconds);
sseEmitter.complete();
} catch (InterruptedException ex) {
logger.error("流式响应异常", ex);
Thread.currentThread().interrupt();
}
}
}
/**
* 初始化 SSE连接 设置一些属性和回调之类的
*
* @param emitter 连接对象
* @param clientId 客户端 ID
*/
private static void initSseEmitter(SseEmitter emitter, String clientId){
// 设置SSE的超时回调
emitter.onTimeout(() -> {
logger.info("[SseEmitter] 连接已超时,正准备关闭,clientId = {}", clientId);
SSE_CACHE.remove(clientId);
});
// 设置SSE的结束回调
emitter.onCompletion(() -> {
logger.info("[SseEmitter] 连接已释放,clientId = {}", clientId);
SSE_CACHE.remove(clientId);
});
// 设置SSE的异常回调
emitter.onError(throwable -> {
logger.error("[SseEmitter] 连接已异常,正准备关闭,clientId = {}", clientId);
SSE_CACHE.remove(clientId);
});
}
}
3、编写调用接口
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import okhttp3.sse.EventSources;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.concurrent.CompletableFuture;
@RestController
@RequestMapping("/chat")
public class StreamChatController {
@PostMapping("/send")
public SseEmitter sendMessage(@RequestParam String username, @RequestParam String message) {
SseEmitter sseEmitter = SseEmitterUtils.getConnection(username);
CompletableFuture.runAsync(()->send(username,message));
return sseEmitter;
}
public void send(String username,String message){
OkHttpClient client = new OkHttpClient();
JSONObject inputJson = new JSONObject();
JSONArray messagesArray = new JSONArray();
JSONObject systemMessage = new JSONObject();
systemMessage.put("role", "system");
systemMessage.put("content", "You are a helpful assistant.");
messagesArray.add(systemMessage);
JSONObject userMessage = new JSONObject();
userMessage.put("role", "user");
userMessage.put("content", message);
messagesArray.add(userMessage);
inputJson.put("messages", messagesArray);
JSONObject payloadJson = new JSONObject();
payloadJson.put("model", "chatglm3-6b");
payloadJson.put("input", inputJson);
JSONObject parametersJson = new JSONObject();
parametersJson.put("result_format", "message");
payloadJson.put("parameters", parametersJson);
String json = payloadJson.toString();
RequestBody body = RequestBody.create(MediaType.parse("application/json"),json);
Request request = new Request.Builder()
.url("https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation")
.post(body)
.addHeader("Authorization", "Bearer 你得API-KEY")
.addHeader("Content-Type", "application/json")
.addHeader("X-DashScope-SSE", "enable")
.build();
// 创建事件监听器
EventSourceListener eventSourceListener = new ChatEventSourceListener(username);
EventSource.Factory factory = EventSources.createFactory(client);
// 创建事件
EventSource eventSource = factory.newEventSource(request, eventSourceListener);
// 与服务器建立连接
eventSource.request();
}
}
4、编写前端
我这个有点粗糙,实际效果比这好的多。
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>SSE Chat</title>
</head>
<body>
<h1>YXY-Chat</h1>
<div id="chat-messages"></div>
<form id="message-form">
<input type="text" id="message-input" placeholder="输入消息">
<button type="submit">发送</button>
</form>
<script>
const chatMessages = document.getElementById('chat-messages');
const messageForm = document.getElementById('message-form');
const messageInput = document.getElementById('message-input');
// 连接到聊天室
const connectToChat = () => {
const username = prompt('Enter your username:');
const eventSource = new EventSource(`/chat/connect?username=${encodeURIComponent(username)}`);
// 接收来自服务器的消息
eventSource.onmessage = function(event) {
const message = event.data;
displayMessage(message);
};
// 处理连接错误
eventSource.onerror = function(event) {
console.error('EventSource error:', event);
eventSource.close();
};
// 提交消息表单
messageForm.addEventListener('submit', function(event) {
event.preventDefault();
const message = messageInput.value.trim();
if (message !== '') {
sendMessage(username, message);
messageInput.value = '';
}
});
};
// 发送消息到服务器
const sendMessage = (username, message) => {
fetch(`/chat/send?username=${encodeURIComponent(username)}&message=${encodeURIComponent(message)}`, {
method: 'POST'
})
.catch(error => console.error('Error sending message:', error));
};
// 在界面上显示消息
const displayMessage = (message) => {
const messageElement = document.createElement('div');
messageElement.textContent = message;
chatMessages.appendChild(messageElement);
};
// 发起连接
connectToChat();
</script>
</body>
</html>
5、发起调用
我们看到其实是成功了,但是前端没有把流数据渲染上去,我不太懂前端,后面改一改试试。
三、总结
我们这只是其中一种模型的对接,其实别的也都差不多,都是基于流可以用http来操作,你可以在你的项目中建立一个AI中台,来对接各种模型,给别的服务提供调用。只是需要看明白每种模型的参数。
而且我们目前只是简单的实现,还存在很多问题,比如okhttp客户端没有做池化,每次都是new出来的。
CompletableFuture的异步调用没有指定线程池,还是共用的默认池,这样会导致可能被别的业务影响。
等等细节问题,我们这里先不做处理,后面如果真的要用,可以着手细节处的优化。