如何使用自有数据微调ChatGLM-6B

news2024/12/29 10:25:31

构建自己的数据集

数据格式:问答对

官网例子

ADGEN 数据集任务为根据输入(content)生成一段广告词(summary)。

{  
    "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",  
    "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"  
}

调整自有样本格式

结合您的任务场景,构建问题和答案,以提取关键字为例。

[
{  
    "content": "请提取下面句子的关键字:'''离岸人民币 (CNH) 兑美元北京时间04:59报7.1657元,较周二纽约尾盘上涨462点,盘中整体交投于7.2155-7.1617元区间。截至发稿,离岸人民币暂报7.1656,升值12基点。Wind数据显示,7月份以来美元指数持续下跌,12日更是大跌1.06%。与此同时,在岸、离岸人民币对美元汇率迎来反弹,在岸、离岸人民币双双收复7.2关。自上周四起,人民币对美元连续走高,截至12日收盘,离岸人民币对美元5个交易日累计涨幅达956个基点。'''",  
    "summary": "离岸人民币 反弹"  
},
{  
    "content": "请提取下面句子的关键字:'''连日来,人民币汇率强势回升,给市场留下深刻“记忆”。

  昨日晚间,在岸、离岸人民币汇率双双收复7.17关口,其中,离岸人民币汇率日内大涨近400点。而近一周以来,在岸、离岸人民币汇率强势回升近千点,升幅均超过1%。

  关于人民币汇率,中国人民银行行长易纲近期在《经济研究》发表《货币政策的自主性、有效性与经济金融稳定》一文。易纲在文中指出,“近年来人民币汇率弹性显著增强,提高了利率调控的自主性,促进了宏观经济稳定,经济基本面稳定又对汇率稳定形成支撑,外汇市场运行更有韧性,利率和汇率之间形成良性互动。”'''",  
    "summary": "人民币汇率 回升 强势 弹性 韧性 支撑"  
},
]

数据集划分和数据量

起码得有几百个吧,除了准备训练集,您还要准备一个验证集,测试与验证的比例可参考8:2、8:1等。如果您有测试集更好了,可以评测下效果。

微调方法

选用官网ptuningv2微调方法,去学习提示向量。对于 ChatGLM-6B 模型基于 P-Tuning v2 的微调。P-Tuning v2 将需要微调的参数量减少到原来的 0.1%,再通过模型量化、Gradient Checkpoint 等方法,最低只需要 7GB 显存即可运行。

脚本参数如何修改

# P-tuning v2
!PRE_SEQ_LEN=128 && LR=2e-2 && CUDA_VISIBLE_DEVICES=0 python3 main.py \
    --do_train \
    --train_file AdvertiseGen/train.json \
    --validation_file AdvertiseGen/dev.json \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path /home/mw/input/ChatGLM6B6449 \
    --output_dir /home/output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --predict_with_generate \
    --max_steps 3000 \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4
# 重要参数注释
--train_file AdvertiseGen/train.json \ # 训练集样本路径
--validation_file AdvertiseGen/dev.json \ # 验证集样本路径
--prompt_column content \ # 样本集json文件中问题/提示的key
--response_column summary \ # 样本集json文件中答案的key
--max_source_length 64 \ # 输入的token最大长度
--max_target_length 64 \ # 输出的token最大长度
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 4 \
--max_steps 3000 \
--quantization_bit 4 # 使用int4量化

--model_name_or_path /home/mw/input/ChatGLM6B6449 \ # 原始预训练模型文件存放位置
--output_dir /home/output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \ # 微调生成的模型检查点/训练后的模型参数存放位置;后续推理预测时,需要再次加载此目录下的最后一个文件夹

官网对于参数的解释:

P-tuning v2
PRE_SEQ_LEN 和 LR 分别是 soft prompt 长度和训练的学习率,可以进行调节以取得最佳的效果。P-Tuning-v2 方法会冻结全部的模型参数,可通过调整 quantization_bit 来被原始模型的量化等级,不加此选项则为 FP16 精度加载。

在默认配置 quantization_bit=4、per_device_train_batch_size=1、gradient_accumulation_steps=16 下,INT4 的模型参数被冻结,一次训练迭代会以 1 的批处理大小进行 16 次累加的前后向传播,等效为 16 的总批处理大小,此时最低只需 6.7G 显存。若想在同等批处理大小下提升训练效率,可在二者乘积不变的情况下,加大 per_device_train_batch_size 的值,但也会带来更多的显存消耗,请根据实际情况酌情调整。

上文以 P-tuning v2 方法采取的参数 quantization_bit=4、per_device_train_batch_size=4、gradient_accumulation_steps=4

如何防止过拟合和灾难遗忘

轮数不要太多。,具体多少是个玄学,如果训练样本有很少,建议先试试1轮。如果样本很多,1w+,建议先试试0.5轮。

轮数的计算方法:per_device_train_batch_size*gradient_accumulation_steps*max_steps/样本总数

微调后效果评估

import os
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer
# 假如大约12G左右的预训练模型文件放在如下目录
model_path = "/home/mw/input/ChatGLM6B6449"
# 载入Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Fine-tuning 后的表现测试
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)
# 此处使用你的微调后得到的模型检查点目录,
#例如上文我们设置检查点目录为/home/output/adgen-chatglm-6b-pt-128-2e-2/,还要指定最终的检查点目录
#因为前面脚本设置了每1000step生成一个临时检查点,请使用最终那个检查点目录,即max-steps对应的目录。
prefix_state_dict = torch.load(os.path.join("/home/output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000", "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
    new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

#V100 机型上可以不进行量化
#print(f"Quantized to 4 bit")
#model = model.quantize(4)
model = model.half().cuda()
model.transformer.prefix_encoder.float()
model = model.eval()

response, history = model.chat(tokenizer, "类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞", history=[])
response
# 以下是微调前的效果评估,即没有加载微调后的检查点文件。
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
model = model.eval()

response, history = model.chat(tokenizer, "类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞", history=[])
response

以上如有问题,欢迎评论区回复和交流~

微调方法简介(理论)

p*tuning 论文综述https://arxiv.org/pdf/2107.13586.pdf
p-tuning v1 论文https://arxiv.org/pdf/2103.10385.pdf
p-tuningv2 论文https://arxiv.org/pdf/2110.07602.pdf
prefix-tuning 论文https://aclanthology.org/2021.acl-long.353.pdf
Prompt Tuning 论文https://arxiv.org/pdf/2104.08691.pdf

​

 以上方法均属于软提示(连续提示),区别硬提示(提示工程,或者理解为用自然语言去提示,也叫离散提示)。

软提示微调的目标,去自动学习一些参数/向量,来模拟人工提示工程;即让模型在嵌入式空间自己学习一个提示向量,加到原来的输入之前,再去激活大模型的“潜力”。

谷歌的prompt tuning,这个名字有争议,因为目前还没有明确的分类和命名方法。学习一个嵌入向量,拼到原有输入的向量表示之前,然后喂给后续的预训练语言模型。如下如所示。

P-tuning v2:在v1的基础上,每一层transformer都加上一个要学习的提示向量。 如下图。

prefix-tuning:全参微调(下图顶部)会更新所有LM参数(红色的Transformer框),并要求为每个任务存储一个完整的模型副本。前缀调整(下图底部),它冻结LM参数并仅优化前缀(红色前缀块)。因此,只需要存储每个任务的前缀,使前缀调整模块化和空间高效。注意,每个垂直块表示一个时间步长的转换器活动。它也是在每个transformer前添加前缀序列来完成目标的。

 优劣之分,目前只体现在论文的比较之中谷歌的prompt tuning吐槽前缀调整包括编码器和解码器网络上的前缀,而提示调整只需要编码器上的提示。v2吐槽v1对复杂任务效果不加。结论v1和谷歌的prompt tuning比较类似,prefix-tuning和v2比较类似。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/766934.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【雕爷学编程】Arduino动手做(22)——8X8 LED点阵MAX7219屏2

37款传感器与模块的提法,在网络上广泛流传,其实Arduino能够兼容的传感器模块肯定是不止37种的。鉴于本人手头积累了一些传感器和模块,依照实践出真知(一定要动手做)的理念,以学习和交流为目的,这…

AI图像生成无需代码连接集简云数据表的方法

1 场景描述 人工智能的出现,各个领域都开始尝试将AI作为提高工作效率的必备工具。除了AI对话等,越来越多的AI图像生成工具也出现在市场上。这些AI图像生成工具可以自动创建惊人的图像、艺术作品和设计,从而帮助设计师和创意人员更快速地实现其…

下个月要备多少货?伙伴云零代码进销存系统让您一目了然

大量企业的商业模式是销售实体商品,他们需要进销存系统来帮助企业管理好采购、销售、仓储的业务流程,从而更高效稳定的获得利润,因此进销存是企业的核心业务场景。来看看伙伴云零代码进销存系统如何精准计算进货出货数量,让中小企…

unable to get local issuer certificate (_ssl.c:992)‘)]

操作系统mac os python 版本 python3.11 import edge_tts import asyncio TEXT "how are you"print(TEXT) voice zh-CN-YunxiNeural output 4.mp3 rate -4% volume 0% async def my_function():tts edge_tts.Communicate(text TEXT,voice voice,rate rate…

上海汽配IPO上会在即:由镇政府控股,募资还要偿还银行贷款?

近日,上海证券交易所披露的信息显示,上海汽车空调配件股份有限公司(以下简称“上海汽配”)将于7月21日接受上市委审议。据贝多财经了解,上海汽配已于7月13日更新了招股书(上会稿)。 本次冲刺IPO…

学Python编程为什么会对学好数学有帮助呢?

Python编程和数学有什么关系呢?Python的起源是怎样的呢? 我们先来简单认识一下Python,和Python交个朋友。 Python的全拼是P—Y—T—H—O—N,发音是Python,汉语解释是蟒蛇的意思。 我们再来看Python的图标&#xff0c…

STM32实现MLX90614非接触测温串口显示(标准库与HAL库实现)

目录 模块选择 编程环境 MLX90614基本原理 通信协议(SMBus通信,类IIC通信) 代码实现 STM32与模块之间接线表 1.标准库实现温度采集 2.HAL库实现温度采集 模块选择 STM32F103C8T6 MLX90614 非接触式红外测温传感器 编程环境 KEIL5&…

了解交换机接口的链路类型(access、trunk、hybrid)

上一个章节中讲到了vlan的作用及使用,这篇了解一下交换机接口的链路类型和什么情况下使用 vlan在数据包中是如何体现的,在上一篇的时候提到测试了一下,从PC1去访问PC4的时候,只从E0/0/2发送给了E0/0/3这是,因为两个接…

手把手GDB调试

确保你的程序有可调式的信息 使用gcc编译一个程序 ,带上一些额外的参数 -o0 -g-o0 :避免编译器优化,使用最低的优化等级,默认的编译选项 -g :生产调试信息 如果你已经有一个工程demo,使用cmake时注意使用Debug模式&…

Java使用poi-tl生成word文档添加超链接及添加多个超链接情况

首先是生成单个超链接情况,很简单 就是通过字符替换就行,但是替换的value格式是 TextRenderData data.put("attachment",Texts.of("文件名").link("http://wenjianlj文件路径.com").create()); 就是在替换的data&#…

spring复习:(39)注解方式的ProxyFactoryBean

一、定义接口 package cn.edu.tju.study.service;public interface MyService {void myMethod(); }二、定义实现类: package cn.edu.tju.study.service;public class MyServiceImpl implements MyService{Overridepublic void myMethod() {System.out.println(&qu…

认识一个失意的李白:如何制作一个人物生平二维码?

电影《长安三万里》的火爆,又一次唤醒了我们对大唐盛世的憧憬和向往。 飞流直下的瀑布、洒落床前的月光、花间独酌的美酒、胡天八月的大雪、越过青天的白鹭、长河孤烟的大漠、钟鼓馔玉的宴会……每每读起,那景象如在眼前。 对于一代又一代读着唐诗、听…

小程序一码跳多端的实现架构。。。

以常用的小程序,微信,支付宝为例, 现在要实现一个二维码,通过微信扫跳转微信小程序,通过支付包扫,跳转支付宝小程序,(其他小程序也如此) 实现思路,H5页面周转…

社区生鲜超市数字化经营怎么做?社区生鲜超市系统一览

社区生鲜超市是一种以货架自助的形式、结合现代超市经营理念,来售卖果蔬、肉类、水产、粮油、熟食等生鲜产品的一种零售形式,通常为小规模的连锁生鲜超市、专营店,主要服务于一个社区、街区等。目前,社区生鲜超市通常拥有较好的区…

【数据结构】二叉树详解(1)

⭐️ 前言 ✨ 二叉树的概念性质 ⭐️ 二叉树链式结构的实现 结构定义&#xff1a; #include <stdio.h> #include <stdlib.h> #include <assert.h>typedef int BinaryTreeDataType;typedef struct BinaryTreeNode {BinaryTreeDataType value;struct Binary…

如何克服Leetcode做题的困境

文章目录 如何克服Leetcode做题的困境问题背景克服困境的建议实践与理论结合切忌死记硬背分析解题思路不要过早看答案迭代式学习寻求帮助坚持与耐心查漏补缺 结论 如何克服Leetcode做题的困境 问题背景 明明自觉学会了不少知识&#xff0c;可真正开始做Leetcode题目时&#x…

用WooCommerce创建一个多用户商城系统和多供应商市场

线上市场是下一波数字化商务。2020 年&#xff0c;超过60% 的线上支出是通过数字市场发生的。人们喜欢从市场上购物&#xff0c;因为它们使购物变得容易。出于同样的原因&#xff0c;企业喜欢通过它们进行销售。通过多用户商城系统和多供应商WooCommerce商城设置&#xff0c;每…

Vue3结果(Result)

可自定义设置以下属性&#xff1a; 结果的状态&#xff0c;决定图标和颜色&#xff08;status&#xff09;&#xff0c;类型&#xff1a;‘success’|‘error’|‘info’|‘warn’|‘404’|‘403’|‘500’&#xff0c;默认&#xff1a;‘info’标题文字&#xff08;title&…

【机密计算-大厂有话说】IBM

什么是机密计算? 机密计算是云计算技术中的一种,通过 CPU 飞地(enclave,飞地可以理解成与世隔绝的世外桃源)隔离处理中的数据。飞地中的内容(处理中的数据)和处理这些数据用到的技术只能被授权的程序代码访问,对于云提供商以及任何人任何事都不可见也不知道。 随着公…

C# 深入理解事件(event)机制

目录 一&#xff0c;引言 二&#xff0c;事件的定义和用法 2.1 同步事件执行 2.2 异步事件执行 2.3 等待异步事件完成 2.4 捕获异常处理中的异常 一&#xff0c;引言 都知道事件的本质是一个多播委托&#xff08;MulticastDelegate)&#xff0c;但对于事件的机制和用法…