【NLP修炼系列之玩转LLM】基于 P-Tuning的高效微调ChatGLM方法

news2024/11/21 1:33:43

引言

上周给大家介绍了另一种基于LORA的高效微调ChatGLM-6B模型的方法。本周分享一下另一种高效的微调方法——P-Tuning v2方法,同时在文章的最后对比一下两种高效微调方法的效果怎么样,只有自己动手做实验了才能很客观的看出哪种方法效果更好,在后面的业务工程上具体的方案选型也才可以有更好的选择。

图片

参考paper:https://arxiv.org/pdf/2110.07602.pdf

一 P-tuningV2概述

P-tuningV2方法是P-tuning方法的改进,主要是基于P-tuning和prefix-tuning技术,引入Deep Prompt Encoding和Multi-task Learning等策略进行优化的。和P-tuning相比改进之后的P-tuning v2可以在不参数量的模型上微调效果达到Fine tuning的水平,而P-tuning只能在参数量达到百亿量级的模型上才会有好的效果。

图片

关于介绍了Prefix-Tuning、P-tuning V1和 V2相关的原理和思路。
并附录一张总结比较全面的思维导图:

图片

二 P-Tuning V2高效微调ChatGLM的步骤

1 项目和环境搭建

这里面项目地址也是官方提供的开源代码:

https://github.com/THUDM/ChatGLM-6B

官方教程文档:
https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning

首先使用git将项目clone到自己的本地或者云服务器里面,微调部分在THUDM/ChatGLM-6B的ptuning目录下面,需要进入ptuning目录,安装项目requirements.txt环境。

cd ptuning

运行微调需要4.27.1版本的transformers。除 ChatGLM-6B 的依赖之外,还需要安装以下依赖

pip install rouge_chinese nltk jieba datasets

2 数据集处理

官方给出的数据处理格式:

ADGEN 数据集为根据输入(content)生成一段广告词(summary)。

{
    "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",    
    "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
}

其实数据格式就是{“content”:“”,“summary”:“”}形式,这里面我没可以处理自己的数据集成这种格式,我的是问答类数据,那么content就是问题,summary就是回答。

图片

用以上数据格式创建自己的train.json和dev.json文件,指定数据文件夹路径后期微调需要改成自己的路径名称。

这里是我的路径:

在这里插入图片描述

这里面给两种数据增强的方法(扩展):

(1)simBert做相似文本生成。

这里给出苏神的代码:

https://github.com/ZhuiyiTechnology/simbert

图片

(2)直接使用ChatGPT生成相似文本。

图片

对比可以很明显的看出GPT3.5生成的相似文本质量会更好,后期如果有这方面需求的可以尝试使用GPT3.5模型接口去做数据扩增的工作。

3 P-Tuning v2微调步骤

(1)修改train.sh文件参数
在这里插入图片描述

修改数据集路径train_file 和validation_file 为自己指定路径名称,更改模型路径model_name_or_path 为自己的模型路径。P-Tuning-v2 方法会冻结全部的模型参数,可通过调整 quantization_bit 来被原始模型的量化等级,不加此选项则为 FP16 精度加载,我使用的是32G V100是完全够用了,就把quantization_bit 4注释了。这里可以根据自己的显卡配置自行设置。

其他参数也是根据自己数据集具体情况来调整合适的值,比如max_source_length 长度根据自己数据content文本长度来定,max_target_length 可以根据自己summary长度大小来设定。

bash train.sh  

开始运行train.sh做微调,这里我自定义的数据集比较少,预计训练时间两三个小时左右。

图片

训练好的模型文件保存在output文件夹下,后期做推理时需要加载。

图片

(2)修改evaluate.sh参数

在这里插入图片描述

和train.sh一样需要修改数据集和模型路径名称。

bash evaluate.sh

图片

模型评估的中间文件predict也保存在output文件夹里面,可以用来评估模型在验证集上的效果。

图片

4 模型推理

模型的推理演示这里面也是使用和LORA一样的官方提供的web_demo2.py界面展示,当然可以运行ptuning路径下的web_demo.sh文件,同样需要改一下模型文件路径。

在这里插入图片描述

对比ptuning路径下的web_demo.py 在web_demo2.py中修改代码:

from transformers import AutoModel, AutoTokenizer, AutoConfig
import streamlit as st
from streamlit_chat import message
import os, sys
import torch


st.set_page_config(
    page_title="ChatGLM-6b 演示",    
    page_icon=":robot:"
)

@st.cache_resource
def get_model():
    tokenizer = AutoTokenizer.from_pretrained("/root/new_datas/chatglm/ChatGLM-6B/model/chatglm-6b", trust_remote_code=True)    
    config = AutoConfig.from_pretrained("/root/new_datas/chatglm/ChatGLM-6B/model/chatglm-6b", trust_remote_code=True)    
    config.pre_seq_len = 128 # 预测时需要模型config中含有 pre_seq_len, 模型才会定义prefix_encoder    
    model = AutoModel.from_pretrained("/root/new_datas/chatglm/ChatGLM-6B/model/chatglm-6b",  config=config, trust_remote_code=True).half().cuda()        
    
    prefix_state_dict = torch.load(os.path.join("/root/new_datas/chatglm/ChatGLM-6B/ptuning/output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000", "pytorch_model.bin"))    
    new_prefix_state_dict = {}    
    for k, v in prefix_state_dict.items():   
        if k.startswith("transformer.prefix_encoder."):        
            new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v    
    model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)   
     
    model = model.eval()    
    return tokenizer, model


MAX_TURNS = 20
MAX_BOXES = MAX_TURNS * 2


def predict(input, max_length, top_p, temperature, history=None):
    tokenizer, model = get_model()    
    if history is None:    
        history = []    
        
    with container:    
        if len(history) > 0:      
            if len(history)>MAX_BOXES:            
               history = history[-MAX_TURNS:]            
            for i, (query, response) in enumerate(history):            
               message(query, avatar_style="big-smile", key=str(i) + "_user")                
               message(response, avatar_style="bottts", key=str(i))       
               
        message(input, avatar_style="big-smile", key=str(len(history)) + "_user")        
        st.write("AI正在回复:")        
        with st.empty():         
            for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,                            
                                               temperature=temperature):                
                query, response = history[-1]                
                st.write(response)    

    return history


container = st.container()

# create a prompt text for the text generation
prompt_text = st.text_area(label="用户命令输入",
            height = 100,            
            placeholder="请在这儿输入您的命令")

max_length = st.sidebar.slider(
    'max_length', 0, 4096, 2048, step=1
)
top_p = st.sidebar.slider(
    'top_p', 0.0, 1.0, 0.6, step=0.01
)
temperature = st.sidebar.slider(
    'temperature', 0.0, 1.0, 0.95, step=0.01
)

if 'state' not in st.session_state:
    st.session_state['state'] = []

if st.button("发送", key="predict"):
    with st.spinner("AI正在思考,请稍等........"):    
        # text generation        
        st.session_state["state"] = predict(prompt_text, max_length, top_p, temperature, st.session_state["state"])

这里注意之前在推理过程遇到过两个问题,之前踩了不少坑:

图片

原因是模型不能简单的把web_demo2.py中的模型路径替换,出现这种问题是因为直接修改一下web_demo2.py启动文件中模型路径ptuning/output/adgen-chatglm-6b-pt-8-2e-2 ,这样是有问题的。于是借鉴了ptuning路径下的web_demo.py中加载模型的方式重新加载模型。

另一个是报错

“AttributeError: ‘ChatGLMModel‘ object has no attribute ‘prefix_encoder‘”

图片

该问题主要是加载模型出现了问题,加载模型部分需要添加config配置参数,并且需要给config指定pre_seq_len,预测时需要模型config中含有 pre_seq_len, 模型才会定义prefix_encoder。


 config.pre_seq_len = 128 # 预测时需要模型config中含有 pre_seq_len, 模型才会定义prefix_encoder 
 model = AutoModel.from_pretrained("/root/new_datas/chatglm/ChatGLM-6B/model/chatglm-6b",  config=config, trust_remote_code=True).half().cuda()   

当然直接用我上面提供修改后的的web_demo2.py代码应该是不会出现什么两个问题的,注意路径对应自己的路径名称。

三 两种高效微调方式效果对比

在我自己的相同数据集上面使用LORA和P-Tuning v2两种微调方法,从结果可以明显的看过来P-Tuning v2的效果是要优于LORA效果的。

{"content": "你是谁?","summary":"你好,我是聚名科技的客服小聚,很高兴为你服务"}
{"content": "你们公司地点在哪啊?","summary":"我们公司在安徽省的省会合肥市置地广场A座办2109室"}
{"content": "你们的企业使命是什么?","summary":"聚名科技的企业使命是“创造更多价值,实现更多梦想”。"}
{"content": "你们的企业愿景是什么?","summary":"聚名科技的企业愿景是“创造更有价值的互联网时代"}
{"content": "那你们公司的业务线是啥?","summary":"聚名科技的主要业务是域名服务,是安徽省最大的域名提供商。"}
{"content": "说一下你们公司秉持的的价值观是什么?","summary":"聚名科技的企业价值观是保持奋斗,追求细节,开放创新,结果导向,客户第一"}

下面是LORA高效微调效果:

图片

下面是P-Tuning V2高效微调效果:

图片

四 结束语

本文主要介绍了另一种基于P-tuningV2的高效微调方法,用实验对比了基于LORA的微调方式,效果还是明显要更好一点的,后面在实际业务的技术选型上也计划使用基于P-tuningV2的高效微调方法来微调公司垂直领域的业务。总结一下在做实际工作和学习过程中调研方案一定要多动手去做实验,才能更客观的选择更好的方案,当然多看理论也是很好处的,之前也看到了很多P-tuningV2的缺点,相对于LORA微调,前者在大模型微调过程中出现的知识遗忘问题要更严重,后期还要多做实验去验证这个问题。

如何学习大模型

现在社会上大模型越来越普及了,已经有很多人都想往这里面扎,但是却找不到适合的方法去学习。

作为一名资深码农,初入大模型时也吃了很多亏,踩了无数坑。现在我想把我的经验和知识分享给你们,帮助你们学习AI大模型,能够解决你们学习中的困难。

我已将重要的AI大模型资料包括市面上AI大模型各大白皮书、AGI大模型系统学习路线、AI大模型视频教程、实战学习,等录播视频免费分享出来,需要的小伙伴可以扫取。

一、AGI大模型系统学习路线

很多人学习大模型的时候没有方向,东学一点西学一点,像只无头苍蝇乱撞,我下面分享的这个学习路线希望能够帮助到你们学习AI大模型。

在这里插入图片描述

二、AI大模型视频教程

在这里插入图片描述

三、AI大模型各大学习书籍

在这里插入图片描述

四、AI大模型各大场景实战案例

在这里插入图片描述

五、结束语

学习AI大模型是当前科技发展的趋势,它不仅能够为我们提供更多的机会和挑战,还能够让我们更好地理解和应用人工智能技术。通过学习AI大模型,我们可以深入了解深度学习、神经网络等核心概念,并将其应用于自然语言处理、计算机视觉、语音识别等领域。同时,掌握AI大模型还能够为我们的职业发展增添竞争力,成为未来技术领域的领导者。

再者,学习AI大模型也能为我们自己创造更多的价值,提供更多的岗位以及副业创收,让自己的生活更上一层楼。

因此,学习AI大模型是一项有前景且值得投入的时间和精力的重要选择。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2177011.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

超详细超实用!!!AI编程之cursor编写设计模式迪米特法则实例(八)

云风网 云风笔记 云风知识库 一、设计模式迪米特法则定义 只与你的直接朋友交谈,不跟“陌生人”说话 其含义是:如果两个软件实体无须直接通信,那么就不应当发生直接的相互调用,可以通过第三方转发该调用。其目的是降低类之间的耦…

leetcode面试题 03.04. 化栈为队

实现一个MyQueue类,该类用两个栈来实现一个队列。 示例: MyQueue queue new MyQueue();queue.push(1); queue.push(2); queue.peek(); // 返回 1 queue.pop(); // 返回 1 queue.empty(); // 返回 false 说明: 你只能使用标准的栈操作 -…

SpringBoot学习笔记(2)

1.静态文件访问 使用IDEA创建Spring Boot项目,会默认创建出classpath:/static/目录,静态资源一般放在这个目录下即可。 如果默认的静态资源过滤策略不能满足开发需求,也可以自定义静态资源过滤策略。 1.1直接访问 在application.properties中…

在线远程考试|基于springBoot的在线远程考试系统设计与实现(附项目源码+论文+数据库)

私信或留言即免费送开题报告和任务书(可指定任意题目) 目录 一、摘要 二、相关技术 三、系统设计 四、数据库设计 五、核心代码 六、论文参考 七、源码获取 一、摘要 信息数据从传统到当代,是一直在变革当中,突…

增强免疫力的9种食物,秋冬尤其要多吃,营养美味又健康!

随着秋风渐起,冬日的脚步也越来越近!这时候,咱们的身体可是需要更多的关爱和呵护。说到秋冬养生,增强免疫力是头等大事。今天就来跟大家聊聊,那些既营养美味,又能帮我们提升免疫力的9种超级食物&#xff0c…

关于将inet引入的相关问题

🏆本文收录于《全栈Bug调优(实战版)》专栏,主要记录项目实战过程中所遇到的Bug或因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&am…

Android 简单实现联系人列表+字母索引效果

效果如上图。 Main Ideas 左右两个列表左列表展示人员数据,含有姓氏首字母的 header item右列表是一个全由姓氏首字母组成的索引列表,点击某个item,展示一个气泡组件(它会自动延时关闭), 左列表滚动并显示与点击的索引列表item …

UDS_1_基础知识

一. 概述 什么是UDS UDS: Unified Diagnostic Service, 统一诊断服务。 UDS是一个在整个汽车系统上经常使用的设备维护协议。其主要遵循:ISO-15765、ISO-14229 等协议。经常应用在整车的各种ECU上面。是一个在整车ECU应用层开发常用的协议之一。 UDS用途: 可以通过诊断…

【Python基础(二)】面向对象

学习分享 1、初始对象1.1、类的定义和使用1.2、类的私有成员和方法1.3、类的构造方法 2、继承的实现和属性的使用 1、初始对象 1.1、类的定义和使用 1.2、类的私有成员和方法 class Clock:id Noneprice Nonedef ring(self):import winsoundwinsound.Beep(2000,3000)clock Clo…

走近Z世代——感受约克VRF天氟地水中央空调营造的多重舒适体验

“我对于约克VRF中央空调最满意之处,就在于这个品牌对用户体验的极致追求。”来自浙江绍兴的范先生,在提到自家安装的约克VRF天氟地水中央空调时发出了如此感慨。作为“Z世代”的一员,年轻的范先生与所有95后年轻人一样,有着自身的独特审美需求,有对潮流的想法和坚持,更有着生活…

被问界/理想赶超!奔驰CEO再度“出马”,寻找中国外援

来自中国车企的全方位、持续施压,让大部分外资车企开始寻求更多的本地化合作来实现技术升级。传统豪华品牌也同样如此。 本周,知情人士透露,梅赛德斯奔驰首席执行官Ola Kllenius计划再次访问中国,目的是进一步寻求和扩大与本地技术…

矩阵系统源码搭建,OEM贴牌,源头技术开发

一、引言 矩阵系统在当今的数字化时代中发挥着重要的作用,无论是在社交媒体管理、内容分发还是数据分析等方面,都有着广泛的应用。本文将详细介绍矩阵系统源码搭建所需准备的内容,帮助开发者更好地理解和实施矩阵系统的搭建过程。 二、技术选…

集合框架 - Map双列集合

01 概述 02 常用方法 03 遍历方式 【快捷键】&#xff1a;ctrlaltv 【说明】&#xff1a;Map.Entry<xx,xx>中&#xff0c;Entry是Map集合中的一个接口&#xff0c;但接口是不能创建对象的&#xff0c;它底层是通过使用Entry的实现类对象来封装键值对数据的。 【说明】&a…

Transformer架构分析

1 encoder 每个称之为一个layer&#xff0c;重复N次 每个里面有两个sublayer&#xff1b;multi-head self-attention MLP后面使用layer normalization LayerNorm(x Sublayer(x)) 残差连接需要两个维度一致&#xff0c;本文采用513。 2 decoder 3 注意力机制 输出维度和val…

华为OD机试 - 匿名信(Python/JS/C/C++ 2024 E卷 100分)

华为OD机试 2024E卷题库疯狂收录中&#xff0c;刷题点这里 专栏导读 本专栏收录于《华为OD机试真题&#xff08;Python/JS/C/C&#xff09;》。 刷的越多&#xff0c;抽中的概率越大&#xff0c;私信哪吒&#xff0c;备注华为OD&#xff0c;加入华为OD刷题交流群&#xff0c;…

测试管理新增视图与高级搜索功能,测试计划支持一键生成缺陷详情,MeterSphere开源持续测试工具v3.3版本发布

2024年9月29日&#xff0c;MeterSphere开源持续测试工具正式发布v3.3版本。 在这一版本中&#xff0c;接口测试方面&#xff0c;接口导入功能支持导入Postman、JMX、HAR和MeterSphere格式的文件&#xff0c;接口场景的自定义请求步骤支持cURL快捷导入&#xff1b;测试管理方面…

大数据实时数仓Hologres(一):Hologres 简单介绍

文章目录 Hologres 简单介绍 一、什么是实时数仓 Hologres 二、产品优势 1、专注实时场景 2、亚秒级交互式分析 3、统一数据服务出口 4、开放生态 5、MaxCompute查询加速 6、计算存储分离架构 三、应用场景 搭建实时数仓 四、产品架构 1、Shared Disk/Storage &am…

sql注入工具升级:自动化时间盲注、布尔盲注

项目地址&#xff1a;https://github.com/iamnotamaster/sql-injecter 给我之前写的sql注入脚本进行了一些升级&#xff0c;此文章就是对升级内容的分析&#xff0c;升级内容如下&#xff1a; 使用占位符foo来填充payload里需要经常修改的部分 自动判断循环 支持爆破和二分查…

UE4_Niagara基础实例—6、蓝图与粒子系统的通信

效果图&#xff1a; 分析&#xff1a; 通过键盘按键来修改粒子系统粒子的大小。 步骤&#xff1a; 1、粒子系统使用上一个实例的粒子系统&#xff0c;大体参数如下&#xff1a; 参数都是乱调的&#xff0c;自己可以随意设置&#xff0c;只注重方法而不在意好看&#xff0c;汗…

求5X5的次小值/次大值

我们知道&#xff0c;求最大值和最小值是比较容易的&#xff0c;就是通过分组判断&#xff0c;然后再次比较即可求出&#xff0c;那么求出次小值/次大值怎么实现呢&#xff0c;本文提供一个设计的思路。 以5x5为例&#xff0c;求出次小值&#xff0c; 第一步&#xff0c;先分…