基于星火大模型的群聊对话分角色要素提取挑战赛-Lora微调与prompt构造

news2024/11/26 5:54:20

赛题连接

https://challenge.xfyun.cn/topic/info?type=role-element-extraction&option=phb

数据集预处理

由于赛题官方限定使用了星火大模型,所以只能调用星火大模型的API或者使用零代码微调
首先训练数据很少是有129条,其中只有chat_textinfos两个属性,chat_text是聊天文本,infos就是提取的信息也是训练集标签,他的平均长度有6000左右对于星火对于信息提取任务已经很长了,而且最长的将近30000,如果使用星火大模型进行询问肯定是要被截断的,而且微调上传的数据也是有最大长度的,我门需要对数据进行处理。
请添加图片描述

数据简单清洗

简单的导包

from dataclasses import dataclass
from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
from sparkai.core.messages import ChatMessage
import pandas as pd
import os
import json
import re
import matplotlib.pyplot as plt
from tqdm import tqdm
from math import ceil
import numpy as np
from copy import deepcopy
import random

tqdm.pandas()
plt.rcParams['font.family'] = ['STFangsong']
plt.rcParams['axes.unicode_minus'] = False

加载数据

data_dir = "./data"
train_file = "train.json"
test_file = "test_data.json"

train_data = pd.read_json(os.path.join(data_dir, train_file))
test_data =  pd.read_json(os.path.join(data_dir, test_file))

首先我们发现数据集中有许多[图片]超链接,这些对数据提取作用不大,我们可以将其去掉,

# 删除表情图片、超链接
train_data['chat_text'] = train_data['chat_text'].str.replace(r"\[[^\[\]]{2,10}\]", "", regex=True)
train_data['chat_text'] = train_data['chat_text'].str.replace("https?://\S+", "", regex=True)
test_data['chat_text'] = test_data['chat_text'].str.replace(r"\[[^\[\]]{2,10}\]", "", regex=True)
test_data['chat_text'] = test_data['chat_text'].str.replace("https?://\S+", "", regex=True)

对于一个人连续的对话我们可以哦将其合并成一个对话

def get_names_phones_and_emails(example):
    names = re.findall(r"(?:\n)?([\u4e00-\u9fa5]+\d+):", example["chat_text"])
    names += re.findall(r"@([\u4e00-\u9fa5]+)\s", example["chat_text"])
    emails = re.findall(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}", example["chat_text"])
    # phones = re.findall(r"1[356789]\d{9}", example["chat_text"]) # 文本中的手机号并不是标准手机号
    phones = re.findall(r"\d{3}\s*\d{4}\s*\d{4}", example["chat_text"]) 
    return pd.Series([set(names), set(phones), set(emails)], index=['names', 'phones', 'emails'])
    
def merge_chat(example):
    for name in example['names']:
        example["chat_text"] = example["chat_text"].replace(f"\n{name}:", f"<|sep|>{name}:")
    chats = example["chat_text"].split("<|sep|>")
    
    last_name = "UNKNOWN"
    new_chats = []
    for chat in chats:
        if chat.startswith(last_name):
            chat = chat.strip("\n")
            chat = "".join(chat.split(":")[1:])
            new_chats[-1] += " " + chat
        else:
            new_chats.append(chat)
            last_name = chat.split(":")[0]
    return pd.Series(["\n".join(new_chats), new_chats], index=["chats", "chat_list"])

# 使用正则表达式获得'names', 'phones', 'emails'
train_data[['names', 'phones', 'emails']] = train_data.apply(get_names_phones_and_emails, axis=1)
test_data[['names', 'phones', 'emails']] = test_data.apply(get_names_phones_and_emails, axis=1)
# 分割聊天记录, 合并连续相同人的聊天
train_data[["chats", "chat_list"]] = train_data.apply(merge_chat, axis=1)
test_data[["chats", "chat_list"]] = test_data.apply(merge_chat, axis=1)

请添加图片描述

补充

补充:后面我们发现数据中chat_text中有许多是重复多编的,我们需要把重复的也给去除掉,这样处理后的数据就会大大减小,使用暴力匹配去除
请添加图片描述

def process(excemple):
    chat_list = excemple["chat_text"].split("\n")

    res = []
    s = 0
    while s < len(chat_list):
        
        i, j = s, s+1
        start_j = j
        while i < len(chat_list) and j < len(chat_list):
            if chat_list[i] == chat_list[j]:
                i += 1
            else:
                if i != s:
                    if j - start_j >10:
                        res += list(range(start_j, j))
                    i = s
                start_j = j
            j += 1
        s += 1
    texts = []
    for i in range(len(chat_list)):
        if i not in res:
            texts.append(chat_list[i])
    return "\n".join(texts)
                    

train_data["chat_text"] = train_data.apply(process, axis = 1)
test_data["chat_text"] = test_data.apply(process, axis = 1)

构造训练集

处理之后其实有些还是很长,我们可以有两种简单粗暴的方法

  1. 截断
  2. 分块
    对于构造训练数据,我们使用了第一种截断的方法,但这两种方法都有一定的缺点
    我们需要查看讯飞官方微调需要的训练集格式,这里我选择使用JSONL格式,并且其每一行是一个JSON字符串,格式为
{"input":"", "target":""}

在这里插入图片描述
训练时我选用了讯飞的spark pro进行训练,其要求训练数据不少于1500条,每一个input+target长度不能大于8000

def process(x):
	# 提示词,我们交代清楚大模型的角色、目标、注意事项,然后提供背景信息,输出格式就可以了
    prompt = f"""Instruction:
你是一个信息要素提取工作人员,你需要从给定的`ChatText`中提取出**客户**的`Infos`中相关信息,将提取的信息填到`Infos`中,
注意事项:
1. 没有的信息无需填写
2. 保持`Infos`的JSON格式不变,没有的信息项也要保留!!!
4. 姓名可以是聊天昵称
5. 注意是客户的信息,不是客服的信息
6. 可以有多个客户信息
ChatText:
{x["chat_text"]}
"""
	# 要求的输出格式
    infos = """"
Infos:
infos": [{
    "基本信息-姓名": "",
    "基本信息-手机号码": "",
    "基本信息-邮箱": "",
    "基本信息-地区": "",
    "基本信息-详细地址": "",
    "基本信息-性别": "",
    "基本信息-年龄": "",
    "基本信息-生日": "",
    "咨询类型": [],
    "意向产品": [],
    "购买异议点": [],
    "客户预算-预算是否充足": "",
    "客户预算-总体预算金额": "",
    "客户预算-预算明细": "",
    "竞品信息": "",
    "客户是否有意向": "",
    "客户是否有卡点": "",
    "客户购买阶段": "",
    "下一步跟进计划-参与人": [],
    "下一步跟进计划-时间点": "",
    "下一步跟进计划-具体事项": ""
}]
"""
	# prompt+infos是文件中的input,answer是文件中的target
    answer = f"""{x["infos"]}""" #target
    total= len(prompt + infos + answer)
    if total > 8000:
        prompt = prompt[:8000-len(infos + answer)]
    return pd.Series([prompt, answer], index=["input", "target"])

data = train_data.apply(process, axis=1)
# 测试集中的target并没有用可以忽略
data = test_data.apply(process, axis=1)

#保存数据
with open(os.path.join(data_dir, "my_train.jsonl"), "w", encoding="utf-8") as f:
    f.write("\n".join([json.dumps(i, ensure_ascii=False) for i in list(data.transpose().to_dict().values())]))
f.close()
with open(os.path.join(data_dir, "my_test.jsonl"), "w", encoding="utf-8") as f:
    f.write("\n".join([json.dumps(i, ensure_ascii=False) for i in list(data.transpose().to_dict().values())]))
f.close()

对于训练数据不少于1500条的要求,我直接将训练集进行了多次复制,只要不少于1500条就可以训练。训练我只训练了两轮。
在这里插入图片描述

使用官方零代码微调

在这里插入图片描述

测试

模型训练好后我们需要到官网将训练好的模型发布,这样才能够调用
在这里插入图片描述
在这里插入图片描述
在我的服务中获取 接口地址APPIDAPIKeyAPISecret,不同版本会有不同

在这里插入图片描述
后续就可以写代码测试了,我们可以询问多轮然后进行投票,减少一次不确定性带来的误差,一轮其实已经可以达到26以上的分数了

from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
from sparkai.core.messages import ChatMessage
import pandas as pd
import os
from tqdm import tqdm
import json


spark = ChatSparkLLM(
    spark_api_url="wss://spark-api-n.xf-yun.com/v3.1/chat",#spark pro微调的url
    spark_app_id="",
    spark_api_key="",
    spark_api_secret="",
    spark_llm_domain="patchv3", #spark pro微调的版本
    streaming=False,
)
def save_result(data):
    with open("./data/result1.json", "w") as f:
        file = data.to_json(orient='records', index=False, force_ascii=False)
        f.write(file)
    f.close()
for j in range(0, 10):
    res = []
    for i in tqdm(range(len(data)), desc=f"正在询问第{j}轮"):
        messages = [ChatMessage(
            role="user",
            content=data.iloc[i]["input"]
        )]
        while True:
            try:
                handler = ChunkPrintHandler()
                a = spark.generate([messages], callbacks=[handler])
                a = json.loads(a.generations[0][0].text.replace("'", "\""))
            except:
                print("出错了")
                continue
            res.append(a)
            break
    multi_res.append(res)
    test_data[f"infos_{j}"] = res
    save_result(test_data)

多轮投票

from typing import Counter, defaultdict


template_infos = {
    "基本信息-姓名": "",
    "基本信息-手机号码": "",
    "基本信息-邮箱": "",
    "基本信息-地区": "",
    "基本信息-详细地址": "",
    "基本信息-性别": "",
    "基本信息-年龄": "",
    "基本信息-生日": "",
    "咨询类型": [],
    "意向产品": [],
    "购买异议点": [],
    "客户预算-预算是否充足": "",
    "客户预算-总体预算金额": "",
    "客户预算-预算明细": "",
    "竞品信息": "",
    "客户是否有意向": "",
    "客户是否有卡点": "",
    "客户购买阶段": "",
    "下一步跟进计划-参与人": [],
    "下一步跟进计划-时间点": "",
    "下一步跟进计划-具体事项": ""
}
result_Infos = []
## 这里的代码已经不是我最初始的代码了,可能会影响到效果,最初我是不管有结果个用户,只投出一个用户,其他信息也是直接全部投票,没有使用根据'基本信息-姓名'进行分开投票,可以自行尝试,投票还是可以提升一点分数的
for multi_infos in zip(*multi_res):
    names_info_dict = defaultdict(list)
    for infos in multi_infos:
        for info in infos:
            names_info_dict[info['基本信息-姓名']].append(info)
    res_infos = []
    for name in names_info_dict:
        l = len(names_info_dict[name])
        print(l)
        if l < 5:
            continue
        infos = template_infos.copy()
      
        for attr in template_infos:
            if isinstance(template_infos[attr], str):
                val_freq = Counter([multi_info.get(attr, "") for multi_info in names_info_dict[name]])
                top_2 = val_freq.most_common(2)
                if len(top_2) == 1:
                    val = top_2[0][0]
                else:
                    if top_2[0][0] == "" and top_2[1][1] < l/2:
                        val = ""
                    elif top_2[0][0] == "":
                        val = top_2[1][0]
                    else:
                        val = top_2[0][0] 
            else:
                val_freq = []
                for multi_info in names_info_dict[name]:
                    val_freq.extend((multi_info.get(attr, [])))
                val_freq = Counter(val_freq)
                val =[val for val, freq in val_freq.most_common(10) if freq > l/2]
            infos[attr] = val
        res_infos.append(infos)
        # if len(res_infos) >= 2:
        #     print(len(names_info_dict[name]),res_infos)
    result_Infos.append(res_infos)
test_data["infos"] = result_Infos
save_result(test_data[["chat_text", "infos"]])

总结

以上只是一个简洁的思路,如果有其他想法欢迎在评论区留言。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1878980.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

模版方法模式详解:使用和实现的指南

目录 模版方法模式模版方法模式结构模版方法模式适合应用场景模版方法模式优缺点练手题目题目描述输入描述输出描述题解 模版方法模式 模板方法模式是一种行为设计模式&#xff0c; 它在超类中定义了一个算法的框架&#xff0c; 允许子类在不修改结构的情况下重写算法的特定步…

游戏推荐: 植物大战僵尸杂交版

下载地址网上一搜就有. 安装就能玩. 2是显血. 4显示植物血, 5是加速. 都是左手主键盘的按钮, 再按是取消. 比较刺激: ps: 设置里面还能打开自动收集阳光和金币.

Elasticsearch (1):ES基本概念和原理简单介绍

Elasticsearch&#xff08;简称 ES&#xff09;是一款基于 Apache Lucene 的分布式搜索和分析引擎。随着业务的发展&#xff0c;系统中的数据量不断增长&#xff0c;传统的关系型数据库在处理大量模糊查询时效率低下。因此&#xff0c;ES 作为一种高效、灵活和可扩展的全文检索…

斜率优化DP——AcWing 303. 运输小猫

斜率优化DP 定义 斜率优化DP&#xff08;Slope Optimization Dynamic Programming&#xff09;是一种高级动态规划技巧&#xff0c;用于优化具有特定形式的状态转移方程。它主要应用于那些状态转移涉及求极值&#xff08;如最小值或最大值&#xff09;的问题中&#xff0c;通…

加密与安全_三种方式实现基于国密非对称加密算法的加解密和签名验签

文章目录 国际算法基础概念常见的加密算法及分类签名和验签基础概念常见的签名算法应用场景 国密算法对称加密&#xff08;DES/AES⇒SM4&#xff09;非对称加密&#xff08;RSA/ECC⇒SM2&#xff09;散列(摘要/哈希)算法&#xff08;MD5/SHA⇒SM3&#xff09; Code方式一 使用B…

每日算法-插值查找

1.概念 插值查找是一种改良版的二分查找,其优势在于,对于较为均匀分布的有序数列,能够更快地使得mid中间游标快速接近目标值. 2.计算公式 中间游标计算公式. 公式说明: 公式的主要思路是,以第一次定位mid中间游标为例, 在接近平均分配的情况下,左右游标之间的差值表示总计供…

Linux线程同步【拿命推荐版】

目录 &#x1f6a9;引言 &#x1f6a9;听故事&#xff0c;引概念 &#x1f6a9;生产者消费者模型 &#x1f680;再次理解生产消费模型 &#x1f680;挖掘特点 &#x1f6a9;条件变量 &#x1f680;条件变量常用接口 &#x1f680;条件变量的原理 &#x1f6a9;引言 上一篇…

新的特性使得数据处理更加直观本教程将带你逐步了解如何使用Java Stream API

本人详解 作者:王文峰,参加过 CSDN 2020年度博客之星,《Java王大师王天师》 公众号:JAVA开发王大师,专注于天道酬勤的 Java 开发问题中国国学、传统文化和代码爱好者的程序人生,期待你的关注和支持!本人外号:神秘小峯 山峯 转载说明:务必注明来源(注明:作者:王文峰…

暑假集中备考2024年汉字小达人:来做18道历年选择题备考吧

结合最近几年的活动安排&#xff0c;预计2024年第11届汉字小达人比赛还有4个多月就启动&#xff0c;那么孩子们如何利用这段时间有条不紊地准备汉字小达人比赛呢&#xff1f; 我的建议是充分利用即将到来的暑假&#xff1a;①把小学1-5年级的语文课本上的知识点熟悉&#xff0…

[数据集][目标检测]围栏破损检测数据集VOC+YOLO格式1196张1类别

数据集格式&#xff1a;Pascal VOC格式YOLO格式(不包含分割路径的txt文件&#xff0c;仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数)&#xff1a;1196 标注数量(xml文件个数)&#xff1a;1196 标注数量(txt文件个数)&#xff1a;1196 标注…

一篇就够了,为你答疑解惑:锂电池一阶模型-离线参数辨识(附代码)

锂电池一阶模型-参数离线辨识 背景模型简介数据收集1. 最大可用容量实验2. 开路电压实验3. 混合动力脉冲特性实验离线辨识对应模型对应代码总结下期预告文章字数有点多,耐心不够的谨慎点击阅读。 下期继续讲解在线参数辨识方法。 背景 最近又在开始重新梳理锂电池建模仿真与S…

Spring底层原理之bean的加载方式八 BeanDefinitionRegistryPostProcessor注解

BeanDefinitionRegistryPostProcessor注解 这种方式和第七种比较像 要实现两个方法 第一个方法是实现工厂 第二个方法叫后处理bean注册 package com.bigdata1421.bean;import org.springframework.beans.BeansException; import org.springframework.beans.factory.config.…

wordpress企业主题和wordpress免费主题

农业畜牧养殖wordpress主题 简洁大气的农业畜牧养殖wordpress主题&#xff0c;农业农村现代化&#xff0c;离不开新农人、新技术。 https://www.jianzhanpress.com/?p3051 SEO优化wordpress主题 简洁的SEO优化wordpress主题&#xff0c;效果好不好&#xff0c;结果会告诉你…

天气网站爬虫及可视化

摘要&#xff1a;随着互联网的快速发展&#xff0c;人们对天气信息的需求也越来越高。本论文基于Python语言&#xff0c;设计并实现了一个天气网站爬虫及可视化系统。该系统通过网络爬虫技术从多个天气网站上获取实时的天气数据&#xff0c;并将数据进行清洗和存储。同时&#…

Halcon 椭圆

一 椭圆 方差的概念: 例1 两人的5次测验成绩如下&#xff1a;X&#xff1a; 50&#xff0c;100&#xff0c;100&#xff0c;60&#xff0c;50 E(X)72&#xff1b;Y&#xff1a; 73&#xff0c; 70&#xff0c; 75&#xff0c;72&#xff0c;70 E(Y)72。平均成绩相同&#xff0c…

idea 用久了代码提示变慢卡顿优化

idea 用久了代码提示变慢卡顿优化 修改虚拟机配置 修改编译构建堆内存

【proteus经典实战】16X192点阵程序

一、简介 6X192点阵程序通常用于表示高分辨率图像或文字&#xff0c;其中16X表示像素阵列的宽度&#xff0c;192表示每个像素阵列中的点阵数&#xff0c;16X192点阵程序需要一定的编程知识和技能才能编写和调试&#xff0c;同时还需要考虑硬件设备的兼容性和性能等因素。 初始…

智能交通(2)——IntelliLight智能交通灯

论文分享&#xff1a;IntelliLight | Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mininghttps://dl.acm.org/doi/10.1145/3219819.3220096摘要 智能交通灯控制对于高效的交通系统至关重要。目前现有的交通信号灯大多由手…

共模和差模的基本概念

电压电流在导体或导线中传播时&#xff0c;存在两种工作形态&#xff1a;共模和差模。电子设备的信号线在进行相互通信时&#xff0c;至少会存在两根导线以形成电传输回路&#xff0c;除此之外&#xff0c;通常还存在第三个导体&#xff0c;即“参考地”。当信号正常传输时&…

JAVA学习笔记-JAVA基础语法-DAY19-File类、递归

第一章 File类 1.1 概述 java.io.File 类是文件和目录路径名的抽象表示&#xff0c;主要用于文件和目录的创建、查找和删除等操作。 1.2 构造方法 public File(String pathname) &#xff1a;通过将给定的路径名字符串转换为抽象路径名来创建新的 File实例。public File(St…