Baseline2链接;基于星火大模型的群聊对话分角色要素提取挑战-baseline2 - 飞桨AI Studio星河社区 (baidu.com)
数据集再构造
因为数据集中的对话数据还是标离散的,我们可以使用大模型自己先对数据集做一次抽取,构建一个新的数据集
这里对原群聊对话设计了一个总结Prompt,目的是将原始对话内容进行精简。方便做微调数据。
一方面直接将群聊对话作为数据集的话,会导致上下文过长,超过限制。还有上下文太长会导致抽取效果变差。
这个prompt相较于原本的baseline区别比较明显,对需要抽取的任务做了一次总结。总结了四个方面:
客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日 客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细 客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段 跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动
通过总结后的数据一方面节约了微调的运算资源,一方面也让数据被清洗后更容易被模型理解,达到更好的抽取效果。
content = ''
prompt = f'''
你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。
****群聊对话****
{content}
****分析数据****
客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日
客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细
客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段
跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动
****注意****
1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容
2.不要输出分析内容
3.输出内容格式为md格式
'''
jsonl_data = {"instruction":"假设你是一个智能交互助手,基于用户的输入文本,解析其中语义,抽取关键信息,以json格式生成结构化的语义内容。","input":"请调小空气净化器的湿度到1","output":"{\"intent\":\"CONTROL\",\"slots\":[{\"name\":\"device\",\"normValue\":\"airCleaner\",\"value\":\"空气净化器\"},{\"name\":\"insType\",\"normValue\":\"set\",\"value\":\"调小\"},{\"name\":\"attr\",\"normValue\":\"humidity\",\"value\":\"湿度\"},{\"name\":\"attrValue\",\"normValue\":\"1\",\"value\":\"1\"}],\"sample\":\"请调小空气净化器的湿度到1\"}"}
这里我们通过星火3.5api清洗原来的数据,总结后按照刚才看到得单行jsonl存储格式将数据存入traindata.jsonl中。
# 训练集制作
with open('train.json', 'r', encoding='utf-8') as file:
data = json.load(file)
with open('traindata.jsonl', 'w', encoding='utf-8') as file:
# 训练集行数(130)不符合要求,范围:1500~90000000
# 遍历数据列表,并将每一行写入文件
# 这里为了满足微调需求我们重复12次数据集 130*12=1560
for line_data in tqdm(data):
line_input = line_data["chat_text"]
line_output = line_data["infos"]
content = line_input
res = chatbot(prompt=prompt)
line_write = {
"instruction":jsonl_data["instruction"],
"input":json.dumps(res, ensure_ascii=False),
"output":json.dumps(line_output, ensure_ascii=False)
}
# 因为数据共有130行,为了能满足训练需要的1500条及以上,我们将正常训练数据扩充12倍。
for time in range(12):
file.write(json.dumps(line_write, ensure_ascii=False) + '\n') # '\n' 用于在每行末尾添加换行符
# 验证集制作 CSV: input,target
with open('test_data.json', 'r', encoding='utf-8') as file:
data_test = json.load(file)
with open('test.csv', 'w', newline='', encoding='utf-8') as csvfile:
csvwriter = csv.writer(csvfile)
csvwriter.writerow(["input","target"])
# 遍历数据列表,并将每一行写入CSV文件
for line_data in tqdm(data_test):
content = line_data["chat_text"]
res = chatbot(prompt=prompt)
## 文件内容校验失败: test.jsonl(不含表头起算)第1行的内容不符合规则,限制每组input和target字符数量总和上限为8000,当前行字符数量:10721
line_list = [res, "-"]
csvwriter.writerow(line_list)
运行过程中可能会有些超时,这是正常情况
微调讯飞模型
先上传数据集
https://training.xfyun.cn/overview
测试集也要上传
然后我们利用上传好的数据集部署微调的任务
训练好之后要部署服务
使用微调模型
接着回到baselin2的代码中,填入我们微调好的模型相关接口信息,在SparkApi.py文件的108行,引号中填入你的resourceId
接着运行剩余的代码,将生成的结果上传评测,下图是我提交的效果
https://challenge.xfyun.cn/h5/detail?type=role-element-extraction&ch=dw24_y0SCtd