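After integrating the service, I packaged it as a utility jar. If you use a different approach, feel free to discuss it so we can all improve.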
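Integration constraints:
1. DashScope models have a QPM (QPS) limit that differs from model to model, so the settings must be adapted to each model.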
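Integration approach:
Because of the QPS limit, many tasks cannot run concurrently. Requests therefore go into a task pool, and a listener periodically checks the status of the tasks in it.
Because callers cannot block until QPS capacity frees up, the model is called asynchronously.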
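Development plan:
1. Create a task and submit it to the task pool.
2. Every 10 seconds, the task listener checks each task in the pool:
1) Not yet started: try to acquire a token; if one is obtained, start the task, otherwise leave it for the next check.
2) Running: check whether the task has timed out; if it has, reset its state and increment its retry count.
3) Failed (retries exhausted): invoke the failure callback and remove the task from the pool.
4) Succeeded: remove the task from the pool.
3. Task execution:
1) Acquire a token; run only if one is obtained.
2) Call the DashScope model through the utility class.
3) Check the outcome: on success, invoke the success callback; on failure, increment the retry count and reset the task state.
4) Return the token.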
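The flow above can be traced with a small in-memory sketch. It is hypothetical and not the author's implementation: an `int` counter stands in for the Redis token counter, `completeRunning()` stands in for the asynchronous model calls finishing, and timeout handling is left out. As in the original code, a task is dropped once its retry count exceeds the limit.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical in-memory sketch of the task-pool flow (not the author's code).
class TaskPoolSketch {
    enum State { PENDING, RUNNING, SUCCEEDED }

    static class Task {
        final String name;
        final boolean willSucceed; // simulated outcome of the model call
        State state = State.PENDING;
        int tryCount = 0;

        Task(String name, boolean willSucceed) {
            this.name = name;
            this.willSucceed = willSucceed;
        }
    }

    private final int maxTokens; // role of DashScopeConfig.qps
    private final int maxTries;  // role of DashScopeConfig.tryCount
    private int tokensInUse = 0;
    final List<Task> pool = new ArrayList<>();
    final List<String> failed = new ArrayList<>(); // tasks handed to the "failure callback"

    TaskPoolSketch(int maxTokens, int maxTries) {
        this.maxTokens = maxTokens;
        this.maxTries = maxTries;
    }

    void submit(Task task) {
        pool.add(task);
    }

    int tokensInUse() {
        return tokensInUse;
    }

    /** One listener pass: clean up finished tasks, start pending ones while tokens remain. */
    void tick() {
        for (Task task : new ArrayList<>(pool)) { // iterate over a copy so removal is safe
            if (task.state == State.RUNNING) {
                continue; // still executing (timeout handling omitted here)
            } else if (task.state == State.SUCCEEDED) {
                pool.remove(task);
            } else if (task.tryCount > maxTries) {
                failed.add(task.name); // failure callback, then drop from the pool
                pool.remove(task);
            } else if (tokensInUse < maxTokens) {
                tokensInUse++; // acquire a token
                task.state = State.RUNNING;
            } else {
                break; // no token available: wait for the next pass
            }
        }
    }

    /** Simulates every running call finishing: record the outcome and return the token. */
    void completeRunning() {
        for (Task task : pool) {
            if (task.state == State.RUNNING) {
                if (task.willSucceed) {
                    task.state = State.SUCCEEDED;
                } else {
                    task.tryCount++;
                    task.state = State.PENDING;
                }
                tokensInUse--; // return the token
            }
        }
    }
}
```

With two tokens and three tasks, the first pass starts two tasks and leaves the third waiting until tokens are returned; a task that keeps failing is retried until its count passes the limit, then reported once.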
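Integration code
1. Prerequisites
For details, see the Alibaba Cloud DashScope developer reference: https://help.aliyun.com/zh/dashscope/developer-reference/acquisition-and-configuration-of-api-key?spm=a2c4g.11186623.0.0.1403193eLiHQfl
The API key obtained by following that reference must be written into the project's configuration file.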
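2. Create the DashScope service jar (aliyun-dashscope)
Implement the integration following the DashScope Java SDK best practices (https://help.aliyun.com/zh/dashscope/java-sdk-best-practices?spm=a2c4g.11186623.0.0.4da417d9T9NKfM), then add the following dependencies to the pom file: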
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>dashscope-sdk-java</artifactId>
<version>2.15.0</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-pool2</artifactId>
</dependency>
<dependency>
<groupId>com.aa.bb</groupId>
<artifactId>common-redis</artifactId>
<version>1.0.0</version>
</dependency>
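dashscope-sdk-java: the DashScope model SDK
commons-pool2: object pool library (no version is given above, so it must be supplied by the parent's dependency management, e.g. spring-boot-dependencies)
common-redis: the Redis utility package from my own project (you can write your own wrapper)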
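Because common-redis is a private package, anyone writing their own wrapper only needs the operations that DashScopeUtils calls. The `RedisService` interface below is hypothetical: its method names and signatures are inferred from those call sites, not taken from a published library. `InMemoryRedisService` is a local-testing stand-in only (one JVM, no persistence, and `keys()` only understands an exact key or a single trailing `*`); a real deployment needs a Redis-backed implementation so the token counter is shared by all instances.

```java
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.LongSupplier;

// Hypothetical interface: only the operations DashScopeUtils uses, inferred from its call sites.
interface RedisService {
    boolean hasKey(String key);
    Object get(String key);
    Long incr(String key);
    Long decr(String key);
    Long decr(String key, long delta);
    void set(String key, Object value, long timeout, TimeUnit unit);
    Set<String> keys(String pattern);
    void del(String key);
}

// In-memory stand-in for local tests only; expiry is checked lazily against an injectable clock.
class InMemoryRedisService implements RedisService {
    private final Map<String, Object> values = new HashMap<>();
    private final Map<String, Long> expiresAt = new HashMap<>(); // key -> deadline in ms
    private final LongSupplier clockMillis;

    InMemoryRedisService(LongSupplier clockMillis) {
        this.clockMillis = clockMillis;
    }

    private void evictIfExpired(String key) {
        Long deadline = expiresAt.get(key);
        if (deadline != null && clockMillis.getAsLong() >= deadline) {
            values.remove(key);
            expiresAt.remove(key);
        }
    }

    private Long add(String key, long delta) {
        evictIfExpired(key);
        long next = Long.parseLong(String.valueOf(values.getOrDefault(key, 0L))) + delta;
        values.put(key, next);
        return next;
    }

    @Override public synchronized boolean hasKey(String key) { evictIfExpired(key); return values.containsKey(key); }
    @Override public synchronized Object get(String key) { evictIfExpired(key); return values.get(key); }
    @Override public synchronized Long incr(String key) { return add(key, 1); }
    @Override public synchronized Long decr(String key) { return add(key, -1); }
    @Override public synchronized Long decr(String key, long delta) { return add(key, -delta); }

    @Override
    public synchronized void set(String key, Object value, long timeout, TimeUnit unit) {
        values.put(key, value);
        expiresAt.put(key, clockMillis.getAsLong() + unit.toMillis(timeout));
    }

    @Override
    public synchronized Set<String> keys(String pattern) {
        boolean prefix = pattern.endsWith("*");
        String stem = prefix ? pattern.substring(0, pattern.length() - 1) : pattern;
        Set<String> result = new HashSet<>();
        for (String key : new HashSet<>(values.keySet())) { // copy: eviction mutates the map
            evictIfExpired(key);
            if (values.containsKey(key) && (prefix ? key.startsWith(stem) : key.equals(stem))) {
                result.add(key);
            }
        }
        return result;
    }

    @Override
    public synchronized void del(String key) {
        values.remove(key);
        expiresAt.remove(key);
    }
}
```

In production, the same interface could be implemented on top of Spring Data Redis's `StringRedisTemplate` (`hasKey`, `opsForValue().increment/decrement/set`, `keys`, `delete`).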
3. Code
1) Create the config class
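import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Configuration;

@Data
@Configuration
@ConfigurationProperties(prefix = "aliyun.dashscope")
public class DashScopeConfig {
    /**
     * API key, bound by @ConfigurationProperties from aliyun.dashscope.apiKey (or api-key)
     */
    private String apiKey;
    /**
     * Default maximum number of output tokens per call
     */
    private int maxTokens = 800;
    /**
     * Model name
     */
    private String model = "qwen-plus";
    /**
     * QPS limit: maximum number of tokens (in-flight calls) handed out at once
     */
    private int qps = 15;
    /**
     * Redis key of the QPS token counter
     */
    private String qpsRedisKey = "aliyun:dashscope:token";
    /**
     * Maximum retry count
     */
    private int tryCount = 3;
    /**
     * Listener polling interval, in milliseconds
     */
    private int time = 10000;
}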
2) Create the object pool factory
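import com.alibaba.dashscope.aigc.generation.Generation;
import org.apache.commons.pool2.BasePooledObjectFactory;
import org.apache.commons.pool2.PooledObject;
import org.apache.commons.pool2.impl.DefaultPooledObject;

/**
 * Creates and wraps Generation clients for GenericObjectPool
 */
public class DashScopePoolFactory extends BasePooledObjectFactory<Generation> {
    @Override
    public Generation create() throws Exception {
        return new Generation();
    }

    @Override
    public PooledObject<Generation> wrap(Generation generation) {
        return new DefaultPooledObject<>(generation);
    }
}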
3) Create the task classes
DashTask: the task class
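import com.alibaba.dashscope.aigc.generation.GenerationParam;
import com.alibaba.dashscope.common.Message;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;

import java.util.function.Consumer;

// @Getter/@Setter rather than @Data: @Data's field-based equals() could make two
// distinct tasks look like the same element of the CopyOnWriteArraySet task pool
@Getter
@Setter
@Slf4j
public class DashTask {
    /**
     * QPS token id held while the task runs
     */
    private volatile Long qpsToken;
    /**
     * Whether the task is currently executing (read by the listener thread)
     */
    private volatile boolean execute = false;
    /**
     * Whether the task succeeded
     */
    private volatile boolean success = false;
    /**
     * Number of attempts so far
     */
    private volatile int tryCount = 0;
    /**
     * Generation parameters
     */
    private GenerationParam generationParam;
    /**
     * Result message
     */
    private volatile Message result;
    /**
     * Success callback
     */
    private Consumer<DashTask> successCallback;
    /**
     * Failure callback
     */
    private Consumer<DashTask> failCallback;

    /**
     * Records the outcome and fires the matching callback
     */
    public void setSuccess(boolean success) {
        if (success) {
            this.onSuccess();
        } else {
            this.onFail();
        }
    }

    /**
     * On success
     */
    public void onSuccess() {
        this.success = true;
        try {
            if (this.successCallback != null) {
                this.successCallback.accept(this);
            }
        } catch (Exception ex) {
            log.error("dash task onSuccess error", ex);
        }
    }

    /**
     * On failure
     */
    public void onFail() {
        this.success = false;
        try {
            if (this.failCallback != null) {
                this.failCallback.accept(this);
            }
        } catch (Exception ex) {
            log.error("dash task onFail error", ex);
        }
    }
}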
DashListener: the task listener. Its parent class Listener comes from the author's project and is not shown.
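import lombok.extern.slf4j.Slf4j;

@Slf4j
public class DashListener extends Listener {
    public DashListener(long interval) {
        super(interval, "dash-listener");
    }

    @Override
    public void run() {
        log.info("DashScope (Tongyi Qianwen) task listener started");
        setExecute(true);
        while (isExecute()) {
            try {
                DashScopeUtils.asyncTaskStart();
                Thread.sleep(getInterval());
            } catch (InterruptedException e) {
                // interrupted (e.g. during shutdown): keep the flag and leave the loop
                Thread.currentThread().interrupt();
                break;
            } catch (Exception e) {
                log.error("DashScope (Tongyi Qianwen) task listener error", e);
            }
        }
    }
}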
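`Listener` is a base class from the author's project and is not shown. A hypothetical minimal version, providing only what DashListener uses (the `(interval, name)` constructor, `isExecute`/`setExecute`, `getInterval`, `start()` inherited from `Thread`, and `shutdown()`), might look like this. Making it a daemon thread is an assumption.

```java
// Hypothetical stand-in for the author's Listener base class (not shown in the original).
abstract class Listener extends Thread {
    private final long interval;
    private volatile boolean execute = false;

    protected Listener(long interval, String name) {
        super(name);
        this.interval = interval;
        setDaemon(true); // assumption: the listener alone should not keep the JVM alive
    }

    public long getInterval() {
        return interval;
    }

    public boolean isExecute() {
        return execute;
    }

    public void setExecute(boolean execute) {
        this.execute = execute;
    }

    /** Stops the polling loop and interrupts a pending sleep. */
    public void shutdown() {
        execute = false;
        interrupt();
    }
}
```

`shutdown()` clears the flag and interrupts the sleep, so the loop exits at its next check instead of waiting out the full interval.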
4) Create the utility classes
DashScopeUtils: base utility class for DashScope
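import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.generation.GenerationParam;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.common.Message;
import com.alibaba.dashscope.common.Role;
import com.alibaba.dashscope.utils.Constants;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.pool2.impl.GenericObjectPool;
import org.apache.commons.pool2.impl.GenericObjectPoolConfig;

import java.util.Iterator;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.TimeUnit;

// RedisService, SpringUtils and AsyncManager come from the author's own project
@Slf4j
public class DashScopeUtils {
    private static volatile DashScopeConfig config;
    private static volatile RedisService redisService;
    /**
     * Status: acquire a token
     */
    public static final int GET_TOKEN_STATUS = 0;
    /**
     * Status: return a token
     */
    public static final int BACK_TOKEN_STATUS = 1;
    /**
     * Task pool, kept in memory
     */
    private static final CopyOnWriteArraySet<DashTask> taskList = new CopyOnWriteArraySet<>();
    /**
     * Pool of Generation clients
     */
    private static volatile GenericObjectPool<Generation> pool;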
    /**
     * Builds a chat message
     *
     * @param role    message role
     * @param content message content
     * @return {@link Message }
     */
    public static Message createMessage(Role role, String content) {
        return Message.builder().role(role.getValue()).content(content).build();
    }
    /**
     * Calls the model synchronously with a pooled Generation client
     *
     * @param param generation parameters
     * @return {@link GenerationResult }
     */
    public static GenerationResult call(GenerationParam param) {
        Generation gen = null;
        try {
            if (param.getMaxTokens() == null) {
                param.setMaxTokens(getConfig().getMaxTokens());
            }
            gen = getPool().borrowObject();
            return gen.call(param);
        } catch (Exception e) {
            log.error(e.getMessage(), e);
            throw new RuntimeException(e.getMessage(), e);
        } finally {
            // always return the borrowed client, even when the call throws
            if (gen != null) {
                getPool().returnObject(gen);
            }
        }
    }
    /**
     * Lazily creates the object pool
     *
     * @return {@link GenericObjectPool }<{@link Generation }>
     */
    public static GenericObjectPool<Generation> getPool() {
        if (pool == null) {
            synchronized (DashScopeUtils.class) {
                if (pool == null) {
                    DashScopePoolFactory poolFactory = new DashScopePoolFactory();
                    // named poolConfig so it does not shadow the static DashScopeConfig field
                    GenericObjectPoolConfig<Generation> poolConfig = new GenericObjectPoolConfig<>();
                    poolConfig.setMaxTotal(64);
                    poolConfig.setMaxIdle(64);
                    poolConfig.setMinIdle(64);
                    Constants.apiKey = getConfig().getApiKey();
                    pool = new GenericObjectPool<>(poolFactory, poolConfig);
                }
            }
        }
        return pool;
    }
    /**
     * Lazily fetches the DashScopeConfig bean
     *
     * @return {@link DashScopeConfig }
     */
    public static DashScopeConfig getConfig() {
        if (config == null) {
            synchronized (DashScopeConfig.class) {
                if (config == null) {
                    config = SpringUtils.getBean(DashScopeConfig.class);
                }
            }
        }
        return config;
    }
    /**
     * One listener pass: reconcile the token counter, then check every task in the pool
     */
    public static void asyncTaskStart() {
        instanceRedis();
        getConfig();
        // number of tokens currently counted as in use
        int current = 0;
        if (redisService.hasKey(config.getQpsRedisKey())) {
            current = Integer.parseInt(redisService.get(config.getQpsRedisKey()).toString());
        }
        // reconcile: if fewer per-token keys are still alive than the counter says,
        // lower the counter so leaked tokens are reclaimed
        if (current > 0) {
            String all = config.getQpsRedisKey() + ":*";
            int size = redisService.keys(all).size();
            if (size < current) {
                redisService.decr(config.getQpsRedisKey(), current - size);
            }
        }
        if (!taskList.isEmpty()) {
            // CopyOnWriteArraySet iterates over a snapshot, so removing tasks here is safe
            Iterator<DashTask> iterator = taskList.iterator();
            while (iterator.hasNext()) {
                DashTask dashTask = iterator.next();
                if (dashTask.isExecute()) {
                    // the per-token key expired (1-minute TTL): treat it as a timeout
                    if (!redisService.hasKey(config.getQpsRedisKey() + ":" + dashTask.getQpsToken())) {
                        dashTask.setExecute(false);
                        dashTask.setTryCount(dashTask.getTryCount() + 1);
                    }
                } else if (dashTask.isSuccess()) {
                    taskList.remove(dashTask);
                } else if (dashTask.getTryCount() > config.getTryCount()) {
                    // retries exhausted: fire the failure callback and drop the task
                    dashTask.setSuccess(false);
                    taskList.remove(dashTask);
                } else if (!asyncTaskStart(dashTask)) {
                    // no token available: stop and wait for the next pass
                    break;
                }
            }
        }
    }
    /**
     * Submits a task to the pool
     *
     * @param dashTask task to submit
     */
    public static void submitTask(DashTask dashTask) {
        taskList.add(dashTask);
    }
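    /**
     * Starts a single task asynchronously if a token is available
     *
     * @param task task to start
     * @return false when no token was available
     */
    private static boolean asyncTaskStart(DashTask task) {
        if (qpsToken(GET_TOKEN_STATUS, task)) {
            // mark as executing before handing off, so the next listener pass cannot
            // start the same task again while it is still queued in the executor
            task.setExecute(true);
            AsyncManager.me().execute(() -> {
                try {
                    GenerationResult call = call(task.getGenerationParam());
                    task.setResult(call.getOutput().getChoices().get(0).getMessage());
                    task.setSuccess(true);
                } catch (Exception e) {
                    log.error("dash task call failed", e);
                    task.setTryCount(task.getTryCount() + 1);
                } finally {
                    task.setExecute(false);
                    qpsToken(BACK_TOKEN_STATUS, task);
                }
            });
            return true;
        }
        return false;
    }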
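    /**
     * Acquires or returns a QPS token
     *
     * @param status GET_TOKEN_STATUS to acquire, BACK_TOKEN_STATUS to return
     * @param task   task the token belongs to
     * @return whether the operation succeeded (always true when returning)
     */
    // synchronized only serializes callers inside this JVM; with several instances
    // the read-then-incr below is not atomic across them
    private static synchronized boolean qpsToken(int status, DashTask task) {
        instanceRedis();
        getConfig();
        int current = 0;
        if (redisService.hasKey(config.getQpsRedisKey())) {
            current = Integer.parseInt(redisService.get(config.getQpsRedisKey()).toString());
        }
        // acquire a token
        if (status == GET_TOKEN_STATUS) {
            if (current < config.getQps()) {
                redisService.incr(config.getQpsRedisKey());
                // token id from a separate, ever-increasing sequence: reusing the counter value
                // could give two running tasks the same id after a decrement
                Long tokenId = redisService.incr(config.getQpsRedisKey() + "-seq");
                task.setQpsToken(tokenId);
                // per-token key with a 1-minute TTL; the listener treats its expiry as a timeout
                redisService.set(config.getQpsRedisKey() + ":" + tokenId, "1", 1, TimeUnit.MINUTES);
                return true;
            }
            return false;
        }
        // return a token
        if (current > 0) {
            redisService.decr(config.getQpsRedisKey());
        }
        redisService.del(config.getQpsRedisKey() + ":" + task.getQpsToken());
        return true;
    }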
    /**
     * Lazily fetches the RedisService bean
     *
     * @return {@link RedisService}
     */
    private static RedisService instanceRedis() {
        if (redisService == null) {
            synchronized (DashScopeUtils.class) {
                if (redisService == null) {
                    redisService = SpringUtils.getBean(RedisService.class);
                }
                if (redisService == null) {
                    throw new RuntimeException("redisService is null");
                }
            }
        }
        return redisService;
    }
}
QianWenUtils: utility class for Tongyi Qianwen (Qwen)
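import com.alibaba.dashscope.aigc.generation.GenerationParam;
import com.alibaba.dashscope.common.Message;
import com.alibaba.dashscope.common.Role;

import java.util.Collections;
import java.util.List;
import java.util.function.Consumer;

public class QianWenUtils {
    /**
     * Single-turn conversation
     *
     * @param content user message
     * @param success callback that receives the model's reply
     */
    public static void call(String content, Consumer<Message> success) {
        Message message = DashScopeUtils.createMessage(Role.USER, content);
        call(Collections.singletonList(message), success);
    }

    /**
     * Multi-turn conversation; the request is queued and answered asynchronously
     *
     * @param messages conversation history
     * @param success  callback that receives the model's reply
     */
    public static void call(List<Message> messages, Consumer<Message> success) {
        try {
            GenerationParam param = GenerationParam.builder()
                    .model(DashScopeUtils.getConfig().getModel())
                    .messages(messages)
                    .resultFormat(GenerationParam.ResultFormat.MESSAGE)
                    .topP(0.8)
                    .maxTokens(600) // overrides DashScopeConfig.maxTokens for these calls
                    .build();
            DashTask dashTask = new DashTask();
            dashTask.setGenerationParam(param);
            dashTask.setSuccessCallback(dash -> success.accept(dash.getResult()));
            DashScopeUtils.submitTask(dashTask);
        } catch (Exception e) {
            throw new RuntimeException("Tongyi Qianwen call failed: " + e.getMessage(), e);
        }
    }
}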
5) Create the runner
The runner is responsible for:
(1) checking that the configuration file is set up correctly;
(2) starting the task listener.
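import cn.hutool.core.util.ObjectUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

// javax.annotation on Spring Boot 2; use jakarta.annotation on Spring Boot 3
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;

@Slf4j
@Component
public class DashScopeRunner {
    private DashListener dashListener;

    @PostConstruct
    public void run() {
        DashScopeConfig config = DashScopeUtils.getConfig();
        if (config == null || ObjectUtil.isEmpty(config.getApiKey())) {
            throw new RuntimeException("DashScope (Tongyi Qianwen) failed to start, please check the configuration file");
        }
        log.info("DashScope (Tongyi Qianwen) started");
        dashListener = new DashListener(config.getTime());
        dashListener.start();
    }

    // @PreDestroy: stop the listener when the Spring context closes, not at startup
    @PreDestroy
    public void shutdown() {
        if (dashListener != null) {
            dashListener.shutdown();
        }
    }
}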
4. Testing
5. Pitfalls
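1) Token count reconciliation: at the start of every pass over the task pool, first check that the token counter matches the tokens actually outstanding (the per-token keys still present in Redis). Otherwise leaked tokens leave too few usable tokens, no task can ever acquire one, and the listener loops forever without making progress.
2) The task pool must be kept in memory and cannot be stored in a cache such as Redis, because tasks carry callback objects (Consumer instances) that cannot be serialized.
3) The success and failure callbacks must be public.
4) When using the object pool (GenericObjectPool), every borrowed object must be returned after use, including when the call throws; otherwise the pool is eventually exhausted and, with the default settings, borrowObject() blocks indefinitely.
5) Keep qps in the config below 15; otherwise requests get rate-limited.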
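To make pitfall 1 concrete, here is a hypothetical in-memory sketch of the token bookkeeping: a counter plus one expiring entry per issued token, mirroring `aliyun:dashscope:token` and `aliyun:dashscope:token:<id>`. `reconcile()` plays the role of the check at the start of each listener pass. Two design choices here are assumptions rather than the author's code: token ids come from a separate sequence so they are never reused, and `release()` only decrements when the token entry still exists, so a token already reclaimed by `reconcile()` is not subtracted twice.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch of the token bookkeeping behind pitfall 1 (not the author's code).
class TokenLedger {
    private final int limit;                              // role of DashScopeConfig.qps
    private int counter = 0;                              // role of the "aliyun:dashscope:token" counter
    private long nextId = 0;                              // separate sequence, so ids are never reused
    private final Map<Long, Long> live = new HashMap<>(); // token id -> expiry (ms), role of "...:token:<id>"

    TokenLedger(int limit) {
        this.limit = limit;
    }

    /** Returns a new token id, or -1 when all tokens are in use. */
    synchronized long acquire(long nowMillis, long ttlMillis) {
        if (counter >= limit) {
            return -1;
        }
        counter++;
        long id = ++nextId;
        live.put(id, nowMillis + ttlMillis);
        return id;
    }

    /** Returns a token; one already reclaimed by reconcile() is not subtracted again. */
    synchronized void release(long id) {
        if (live.remove(id) != null && counter > 0) {
            counter--;
        }
    }

    /** Drops expired entries and lowers the counter to the number of live tokens. */
    synchronized void reconcile(long nowMillis) {
        live.values().removeIf(expiry -> nowMillis >= expiry);
        if (live.size() < counter) {
            counter = live.size();
        }
    }

    synchronized int inUse() {
        return counter;
    }
}
```

Without `reconcile()`, two tokens that leak (their tasks never return them) would keep `inUse()` at the limit forever, which is exactly the stall described in pitfall 1.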