文章目录
- 准备工作
- 整体思路
- 接入大模型
- 服务端和大模型连接
- 客户端和服务端的连接
- 测试
准备工作
- 到讯飞星火大模型上根据官方的提示申请tokens
  申请成功后可以获得对应的secret,key还有之前创建的应用的appId,这些就是我们要用到的信息
- 搭建项目
整体思路
考虑到敏感信息等安全性问题,这里和大模型的交互都放到后端去做。
客户端,服务端,星火大模型均通过Websocket的方式建立连接,用户询问问题时向SpringBoot服务端发送消息,服务端接收到后,创建与星火大模型的连接,并访问大模型,获取到请求结果后发送给客户端
如果想实现根据上下文问答,就要把历史问题和历史回答结果全部传回大模型服务端
请求参数的构建和响应参数的解析请参照官方的Web API文档
接入大模型
服务端和大模型连接
/**
* 与大模型建立Socket连接
*
* @author gwj
*/
@Slf4j
public class BigModelNew extends WebSocketListener {
public static final String appid = "appid";
// 对话历史存储集合
public static Map<Long,List<RoleContent>> hisMap = new ConcurrentHashMap<>();
public static String totalAnswer = ""; // 大模型的答案汇总
private static String newAsk = "";
public static synchronized void ask(String question) {
newAsk = question;
}
public static final Gson gson = new Gson();
// 项目中需要用到的参数
private Long userId;
private Boolean wsCloseFlag;
// 构造函数
public BigModelNew(Long userId, Boolean wsCloseFlag) {
this.userId = userId;
this.wsCloseFlag = wsCloseFlag;
}
// 由于历史记录最大上线1.2W左右,需要判断是能能加入历史
public boolean canAddHistory() {
int len = 0;
List<RoleContent> list = hisMap.get(userId);
for (RoleContent temp : list) {
len = len + temp.getContent().length();
}
if (len > 12000) {
list.remove(0);
list.remove(1);
list.remove(2);
list.remove(3);
list.remove(4);
return false;
} else {
return true;
}
}
// 线程来发送参数
class ModelThread extends Thread {
private WebSocket webSocket;
private Long userId;
public ModelThread(WebSocket webSocket, Long userId) {
this.webSocket = webSocket;
this.userId = userId;
}
public void run() {
try {
JSONObject requestJson = new JSONObject();
JSONObject header = new JSONObject(); // header参数
header.put("app_id", appid);
header.put("uid", userId+UUID.randomUUID().toString().substring(0,16));
JSONObject parameter = new JSONObject(); // parameter参数
JSONObject chat = new JSONObject();
chat.put("domain", "4.0Ultra");
chat.put("temperature", 0.5);
chat.put("max_tokens", 4096);
parameter.put("chat", chat);
JSONObject payload = new JSONObject(); // payload参数
JSONObject message = new JSONObject();
JSONArray text = new JSONArray();
// 历史问题获取
List<RoleContent> list = hisMap.get(userId);
if (list != null && !list.isEmpty()) {
//log.info("his:{}",list);
for (RoleContent tempRoleContent : list) {
text.add(JSON.toJSON(tempRoleContent));
}
}
// 最新问题
RoleContent roleContent = new RoleContent();
roleContent.setRole("user");
roleContent.setContent(newAsk);
text.add(JSON.toJSON(roleContent));
hisMap.computeIfAbsent(userId, k -> new ArrayList<>());
hisMap.get(userId).add(roleContent);
message.put("text", text);
payload.put("message", message);
requestJson.put("header", header);
requestJson.put("parameter", parameter);
requestJson.put("payload", payload);
// System.out.println(requestJson);
webSocket.send(requestJson.toString());
// 等待服务端返回完毕后关闭
while (true) {
// System.err.println(wsCloseFlag + "---");
Thread.sleep(200);
if (wsCloseFlag) {
break;
}
}
webSocket.close(1000, "");
} catch (Exception e) {
log.error("【大模型】发送消息错误,{}",e.getMessage());
}
}
}
@Override
public void onOpen(WebSocket webSocket, Response response) {
super.onOpen(webSocket, response);
log.info("上线");
ModelThread modelThread = new ModelThread(webSocket,userId);
modelThread.start();
}
@Override
public void onMessage(WebSocket webSocket, String text) {
JsonParse json = gson.fromJson(text, JsonParse.class);
if (json.getHeader().getCode() != 0) {
log.error("发生错误,错误码为:{} sid为:{}", json.getHeader().getCode(),json.getHeader().getSid());
//System.out.println(json);
webSocket.close(1000, "");
}
List<Text> textList = json.getPayload().getChoices().getText();
for (Text temp : textList) {
// 向客户端发送回答信息,如有存储问答需求,在此处存储
ModelChatEndpoint.sendMsgByUserId(userId,temp.getContent());
totalAnswer = totalAnswer + temp.getContent();
}
if (json.getHeader().getStatus() == 2) {
// 可以关闭连接,释放资源
if (canAddHistory()) {
RoleContent roleContent = new RoleContent();
roleContent.setRole("assistant");
roleContent.setContent(totalAnswer);
hisMap.get(userId).add(roleContent);
} else {
hisMap.get(userId).remove(0);
RoleContent roleContent = new RoleContent();
roleContent.setRole("assistant");
roleContent.setContent(totalAnswer);
hisMap.get(userId).add(roleContent);
}
//收到响应后让等待的线程停止等待
wsCloseFlag = true;
}
}
@Override
public void onFailure(WebSocket webSocket, Throwable t, Response response) {
super.onFailure(webSocket, t, response);
try {
if (null != response) {
int code = response.code();
System.out.println("onFailure code:" + code);
System.out.println("onFailure body:" + response.body().string());
if (101 != code) {
System.out.println("connection failed");
System.exit(0);
}
}
} catch (IOException e) {
e.printStackTrace();
}
}
// 鉴权方法
public static String getAuthUrl(String hostUrl, String apiKey, String apiSecret) throws Exception {
URL url = new URL(hostUrl);
// 时间
SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);
format.setTimeZone(TimeZone.getTimeZone("GMT"));
String date = format.format(new Date());
// 拼接
String preStr = "host: " + url.getHost() + "\n" + "date: " + date + "\n" + "GET " + url.getPath() + " HTTP/1.1";
// System.err.println(preStr);
// SHA256加密
Mac mac = Mac.getInstance("hmacsha256");
SecretKeySpec spec = new SecretKeySpec(apiSecret.getBytes(StandardCharsets.UTF_8), "hmacsha256");
mac.init(spec);
byte[] hexDigits = mac.doFinal(preStr.getBytes(StandardCharsets.UTF_8));
// Base64加密
String sha = Base64.getEncoder().encodeToString(hexDigits);
// System.err.println(sha);
// 拼接
String authorization = String.format("api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, "hmac-sha256", "host date request-line", sha);
// 拼接地址
HttpUrl httpUrl = Objects.requireNonNull(HttpUrl.parse("https://" + url.getHost() + url.getPath())).newBuilder().//
addQueryParameter("authorization", Base64.getEncoder().encodeToString(authorization.getBytes(StandardCharsets.UTF_8))).//
addQueryParameter("date", date).//
addQueryParameter("host", url.getHost()).//
build();
return httpUrl.toString();
}
}
其中用来接收响应参数的相关实体类如下:
/**
 * Top-level shape of one response frame from the large model.
 * BigModelNew.onMessage deserializes each incoming text frame into this type with Gson.
 */
@Data
public class JsonParse {
// Status information of the frame (error code, frame status, sid).
Header header;
// Answer content carried by the frame.
Payload payload;
}
/**
 * Header section of a response frame.
 */
@Data
public class Header {
// Result code; BigModelNew.onMessage treats any non-zero value as an error.
int code;
// Frame status; BigModelNew.onMessage treats 2 as the final frame of an answer.
int status;
// Identifier logged together with the error code (presumably the session id - confirm against the API docs).
String sid;
}
/**
 * Payload section of a response frame; wraps the answer choices.
 */
@Data
public class Payload {
// Container holding the answer text fragments.
Choices choices;
}
/**
 * Choices section of a response payload.
 */
@Data
public class Choices {
// Answer fragments of this frame; BigModelNew.onMessage iterates over them in order.
List<Text> text;
}
/**
 * One answer fragment inside a response frame.
 */
@Data
public class Text {
// Role attached to the fragment.
String role;
// Fragment text; forwarded to the client and appended to the accumulated answer.
String content;
}
/**
 * One message of a conversation, stored in the per-user history and serialized
 * into the "text" array of the request payload.
 */
@Data
public class RoleContent {
// Who produced the message; the visible code sets "user" for questions and "assistant" for answers.
String role;
// Message text.
String content;
}
客户端和服务端的连接
/**
* 接收客户端请求
*
* @author gwj
* @date 2024/10/29 16:51
*/
@ServerEndpoint(value = "/ws/model", configurator = GetUserConfigurator.class)
@Component
@Slf4j
public class ModelChatEndpoint {
private static AtomicInteger online = new AtomicInteger(0);
private static final ConcurrentHashMap<Long,ModelChatEndpoint> wsMap = new ConcurrentHashMap<>();
private static BigModelConfig config;
@Resource
private BigModelConfig modelConfig;
@PostConstruct
public void init() {
config = modelConfig;
}
private Session session;
private Long userId;
@OnOpen
public void onOpen(EndpointConfig config, Session session) {
String s = config.getUserProperties().get("id").toString();
userId = Long.parseLong(s);
this.session = session;
wsMap.put(userId,this);
online.incrementAndGet();
log.info("用户{},连接成功,在线人数:{}",userId,online);
}
@OnClose
public void onClose() {
wsMap.remove(userId);
online.incrementAndGet();
log.info("{},退出,在线人数:{}",userId,online);
}
@OnError
public void onError(Session session, Throwable error) {
log.error("连接出错,{}", error.getMessage());
}
@OnMessage
public void onMessage(String message,Session session) throws Exception {
BigModelNew.ask(message);
//构建鉴权url
String authUrl = BigModelNew.getAuthUrl(config.getHostUrl(), config.getApiKey(), config.getApiSecret());
OkHttpClient client = new OkHttpClient.Builder().build();
String url = authUrl.replace("http://", "ws://").replace("https://", "wss://");
Request request = new Request.Builder().url(url).build();
WebSocket webSocket = client.newWebSocket(request,
new BigModelNew(this.userId, false));
log.info("收到客户端{}的消息:{}", userId, message);
}
private void sendMsg(String message) {
try {
this.session.getBasicRemote().sendText(message);
} catch (IOException e) {
log.error("客户端{}发送{}失败",userId,message);
}
}
/**
* 根据userId向用户发送消息
*
* @param userId 用户id
* @param message 消息
*/
public static void sendMsgByUserId(Long userId,String message) {
if (userId != null && wsMap.containsKey(userId)) {
wsMap.get(userId).sendMsg(message);
}
}
}
测试
这样就简单实现了一个ai问答功能