一、添加依赖
<!-- Java ChatGPT client library (com.github.plexpt:chatgpt) required by the code in this guide -->
<dependency>
    <groupId>com.github.plexpt</groupId>
    <artifactId>chatgpt</artifactId>
    <version>4.0.7</version>
</dependency>
二、重写 SseEmitter,将响应编码改为 UTF-8
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
/**
 * @ClassName: SseEmitterUTF8
 * @Description: SseEmitter subclass that forces the UTF-8 charset on the event-stream response
 * @Author: XQD
 * @Date: 2023/11/6 18:51
 */
public class SseEmitterUTF8 extends SseEmitter {

    /**
     * {@code text/event-stream;charset=UTF-8}. MediaType is immutable, so one shared instance is
     * built once instead of on every response.
     */
    private static final MediaType TEXT_EVENT_STREAM_UTF8 =
            new MediaType("text", "event-stream", StandardCharsets.UTF_8);

    /**
     * Creates an emitter with the given timeout.
     *
     * @param timeout the timeout, passed through unchanged to {@link SseEmitter#SseEmitter(Long)}
     */
    public SseEmitterUTF8(Long timeout) {
        super(timeout);
    }

    /**
     * Lets the superclass extend the response first, then overwrites whatever Content-Type it set
     * with {@code text/event-stream;charset=UTF-8}, so clients decode non-ASCII text (e.g. Chinese)
     * as UTF-8.
     *
     * @param outputMessage the response whose headers are adjusted
     */
    @Override
    protected void extendResponse(ServerHttpResponse outputMessage) {
        super.extendResponse(outputMessage);
        HttpHeaders headers = outputMessage.getHeaders();
        headers.setContentType(TEXT_EVENT_STREAM_UTF8);
    }
}
三、核心方法:流式对话接口
/**
 * Base URL of the chat-completion service passed to {@code ChatGPTStream.builder().apiHost(...)}.
 * NOTE(review): placeholder value — replace with the real endpoint (ideally from configuration).
 */
private static final String OPENAI_API_HOST = "chatgpt的接口";

/**
 * Streams a chat completion for {@code query} back to the caller as Server-Sent Events.
 *
 * <p>The query is sent as a single chat message. A {@link SseStreamListener} constructed around an
 * {@link SseEmitterUTF8} is handed to {@code ChatGPTStream#streamChatCompletion}. When the listener
 * reports completion, the final message is printed and the emitter is completed, closing the
 * response.
 *
 * @param query the user prompt, sent as the only chat message
 * @return the emitter the HTTP response is bound to
 */
@GetMapping(value = "/v1/stream")
@PermitAll
public SseEmitter streamEvents(@RequestParam String query) {
    // -1L is the emitter timeout; it is passed to the servlet container as the async timeout
    // (typically treated as "no timeout"), so the emitter must be completed explicitly below.
    SseEmitterUTF8 sseEmitter = new SseEmitterUTF8(-1L);
    SseStreamListener listener = new SseStreamListener(sseEmitter);
    // For console debugging, a ConsoleStreamListener can be used instead of the SSE listener.

    // Register the completion callback BEFORE starting the stream, so it is already in place
    // however quickly the stream finishes; registering it afterwards could miss the completion.
    listener.setOnComplate(msg -> {
        // The answer is complete; post-processing (persisting, auditing, ...) can be added here.
        System.out.println(msg);
        // Close the SSE response; with the -1L timeout it would otherwise never be released.
        sseEmitter.complete();
    });

    Message message = Message.of(query);
    ChatCompletion chatCompletion = ChatCompletion.builder()
            .model("llama2")
            .messages(Arrays.asList(message))
            .stream(true)
            .temperature(1)
            .topP(1)
            .presencePenalty(0)
            .frequencyPenalty(0)
            .maxTokens(250)
            .build();

    // To route requests through an HTTP proxy, uncomment the next line and the .proxy(proxy) call below.
    //Proxy proxy = Proxys.http("192.168.1.98", 7890);
    ChatGPTStream chatGPTStream = ChatGPTStream.builder()
            .timeout(600)
            // NOTE(review): dummy key — assumes the target host does not validate API keys; confirm.
            .apiKey("empty")
            //.proxy(proxy)
            .apiHost(OPENAI_API_HOST)
            .build()
            .init();

    chatGPTStream.streamChatCompletion(chatCompletion, listener);
    return sseEmitter;
}
启动后可直接在浏览器中访问 /v1/stream?query=你好 进行测试(若控制器类上配置了统一路径前缀,需一并加上)。