在自定义数据集上微调Alpaca和LLaMA

news2025/2/4 18:04:51

本文将介绍使用LoRa在本地机器上微调Alpaca和LLaMA,我们将介绍在特定数据集上对Alpaca LoRa进行微调的整个过程,本文将涵盖数据处理、模型训练和使用流行的自然语言处理库(如Transformers和hugs Face)进行评估。此外还将介绍如何使用grado应用程序部署和测试模型。

配置

首先,alpaca-lora1 GitHub存储库提供了一个脚本(finetune.py)来训练模型。在本文中,我们将利用这些代码并使其在Google Colab环境中无缝地工作。

首先安装必要的依赖:

 !pip install -U pip
 !pip install accelerate==0.18.0
 !pip install appdirs==1.4.4
 !pip install bitsandbytes==0.37.2
 !pip install datasets==2.10.1
 !pip install fire==0.5.0
 !pip install git+https://github.com/huggingface/peft.git
 !pip install git+https://github.com/huggingface/transformers.git
 !pip install torch==2.0.0
 !pip install sentencepiece==0.1.97
 !pip install tensorboardX==2.6
 !pip install gradio==3.23.0

安装完依赖项后,继续导入所有必要的库,并为matplotlib绘图配置设置:

 import transformers
 import textwrap
 from transformers import LlamaTokenizer, LlamaForCausalLM
 import os
 import sys
 from typing import List
 
 from peft import (
     LoraConfig,
     get_peft_model,
     get_peft_model_state_dict,
     prepare_model_for_int8_training,
 )
 
 import fire
 import torch
 from datasets import load_dataset
 import pandas as pd
 
 import matplotlib.pyplot as plt
 import matplotlib as mpl
 import seaborn as sns
 from pylab import rcParams
 
 %matplotlib inline
 sns.set(rc={'figure.figsize':(10, 7)})
 sns.set(rc={'figure.dpi':100})
 sns.set(style='white', palette='muted', font_scale=1.2)
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 DEVICE

数据

我们这里使用BTC Tweets Sentiment dataset4,该数据可在Kaggle上获得,包含大约50,000条与比特币相关的tweet。为了清理数据,删除了所有以“转发”开头或包含链接的推文。

使用Pandas来加载CSV:

 df = pd.read_csv("bitcoin-sentiment-tweets.csv")
 df.head()

通过清理的数据集有大约1900条推文。

情绪标签用数字表示,其中-1表示消极情绪,0表示中性情绪,1表示积极情绪。让我们看看它们的分布:

 df.sentiment.value_counts()
 
 
 # 0.0    860
 # 1.0    779
 # -1.0    258
 # Name: sentiment, dtype: int64

数据量差不多,虽然负面评论较少,但是可以简单的当成平衡数据来对待:

 df.sentiment.value_counts().plot(kind='bar');

构建JSON数据集

原始Alpaca存储库中的dataset5格式由一个JSON文件组成,该文件具有具有指令、输入和输出字符串的对象列表。

让我们将Pandas的DF转换为一个JSON文件,该文件遵循原始Alpaca存储库中的格式:

 def sentiment_score_to_name(score: float):
     if score > 0:
         return "Positive"
     elif score < 0:
         return "Negative"
     return "Neutral"
 
 dataset_data = [
     {
         "instruction": "Detect the sentiment of the tweet.",
         "input": row_dict["tweet"],
         "output": sentiment_score_to_name(row_dict["sentiment"])
     }
     for row_dict in df.to_dict(orient="records")
 ]
 
 dataset_data[0]

结果如下:

 {
   "instruction": "Detect the sentiment of the tweet.",
   "input": "@p0nd3ea Bitcoin wasn't built to live on exchanges.",
   "output": "Positive"
 }

然后就是保存生成的JSON文件,以便稍后使用它来训练模型:

 import json
 with open("alpaca-bitcoin-sentiment-dataset.json", "w") as f:
    json.dump(dataset_data, f)

模型权重

虽然原始的Llama模型权重不可用,但它们被泄露并随后被改编用于HuggingFace Transformers库。我们将使用decapoda-research6:

 BASE_MODEL = "decapoda-research/llama-7b-hf"
 
 model = LlamaForCausalLM.from_pretrained(
     BASE_MODEL,
     load_in_8bit=True,
     torch_dtype=torch.float16,
     device_map="auto",
 )
 
 tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
 
 tokenizer.pad_token_id = (
     0  # unk. we want this to be different from the eos token
 )
 tokenizer.padding_side = "left"

这段代码使用来自Transformers库的LlamaForCausalLM类加载预训练的Llama 模型。load_in_8bit=True参数使用8位量化加载模型,以减少内存使用并提高推理速度。

代码还使用LlamaTokenizer类为同一个Llama模型加载标记器,并为填充标记设置一些附加属性。具体来说,它将pad_token_id设置为0以表示未知的令牌,并将padding_side设置为“left”以填充左侧的序列。

数据集加载

现在我们已经加载了模型和标记器,下一步就是加载之前保存的JSON文件,使用HuggingFace数据集库中的load_dataset()函数:

 data = load_dataset("json", data_files="alpaca-bitcoin-sentiment-dataset.json")
 data["train"]

结果如下:

 Dataset({
     features: ['instruction', 'input', 'output'],
     num_rows: 1897
 })

接下来,我们需要从加载的数据集中创建提示并标记它们:

 def generate_prompt(data_point):
     return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.  # noqa: E501
 ### Instruction:
 {data_point["instruction"]}
 ### Input:
 {data_point["input"]}
 ### Response:
 {data_point["output"]}"""
 
 
 def tokenize(prompt, add_eos_token=True):
     result = tokenizer(
         prompt,
         truncation=True,
         max_length=CUTOFF_LEN,
         padding=False,
         return_tensors=None,
     )
     if (
         result["input_ids"][-1] != tokenizer.eos_token_id
         and len(result["input_ids"]) < CUTOFF_LEN
         and add_eos_token
     ):
         result["input_ids"].append(tokenizer.eos_token_id)
         result["attention_mask"].append(1)
 
     result["labels"] = result["input_ids"].copy()
 
     return result
 
 def generate_and_tokenize_prompt(data_point):
     full_prompt = generate_prompt(data_point)
     tokenized_full_prompt = tokenize(full_prompt)
     return tokenized_full_prompt

第一个函数generate_prompt从数据集中获取一个数据点,并通过组合指令、输入和输出值来生成提示。第二个函数tokenize接收生成的提示,并使用前面定义的标记器对其进行标记。它还向输入序列添加序列结束标记,并将标签设置为与输入序列相同。第三个函数generate_and_tokenize_prompt结合了前两个函数,生成并标记提示。

数据准备的最后一步是将数据集分成单独的训练集和验证集:

 train_val = data["train"].train_test_split(
     test_size=200, shuffle=True, seed=42
 )
 train_data = (
     train_val["train"].map(generate_and_tokenize_prompt)
 )
 val_data = (
     train_val["test"].map(generate_and_tokenize_prompt)
 )

我们还需要数据进行打乱,并且获取200个样本作为验证集。generate_and_tokenize_prompt()函数应用于训练和验证集中的每个示例,生成标记化的提示。

训练

训练过程需要几个参数,这些参数主要来自原始存储库中的微调脚本:

 LORA_R = 8
 LORA_ALPHA = 16
 LORA_DROPOUT= 0.05
 LORA_TARGET_MODULES = [
     "q_proj",
     "v_proj",
 ]
 
 BATCH_SIZE = 128
 MICRO_BATCH_SIZE = 4
 GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
 LEARNING_RATE = 3e-4
 TRAIN_STEPS = 300
 OUTPUT_DIR = "experiments"

下面就可以为训练准备模型了:

 model = prepare_model_for_int8_training(model)
 config = LoraConfig(
     r=LORA_R,
     lora_alpha=LORA_ALPHA,
     target_modules=LORA_TARGET_MODULES,
     lora_dropout=LORA_DROPOUT,
     bias="none",
     task_type="CAUSAL_LM",
 )
 model = get_peft_model(model, config)
 model.print_trainable_parameters()
 
 #trainable params: 4194304 || all params: 6742609920 || trainable%: 0.06220594176090199

我们使用LORA算法初始化并准备模型进行训练,通过量化可以减少模型大小和内存使用,而不会显着降低准确性。

LoraConfig7是一个为LORA算法指定超参数的类,例如正则化强度(lora_alpha)、dropout概率(lora_dropout)和要压缩的目标模块(target_modules)。

然后就可以直接使用Transformers库进行训练:

 training_arguments = transformers.TrainingArguments(
     per_device_train_batch_size=MICRO_BATCH_SIZE,
     gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
     warmup_steps=100,
     max_steps=TRAIN_STEPS,
     learning_rate=LEARNING_RATE,
     fp16=True,
     logging_steps=10,
     optim="adamw_torch",
     evaluation_strategy="steps",
     save_strategy="steps",
     eval_steps=50,
     save_steps=50,
     output_dir=OUTPUT_DIR,
     save_total_limit=3,
     load_best_model_at_end=True,
     report_to="tensorboard"
 )

这段代码创建了一个TrainingArguments对象,该对象指定用于训练模型的各种设置和超参数。这些包括:

  • gradient_accumulation_steps:在执行向后/更新之前累积梯度的更新步数。
  • warmup_steps:优化器的预热步数。
  • max_steps:要执行的训练总数。
  • learning_rate:学习率。
  • fp16:使用16位精度进行训练。

DataCollatorForSeq2Seq是transformer库中的一个类,它为序列到序列(seq2seq)模型创建一批输入/输出序列。在这段代码中,DataCollatorForSeq2Seq对象用以下参数实例化:

 data_collator = transformers.DataCollatorForSeq2Seq(
     tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
 )

pad_to_multiple_of:表示最大序列长度的整数,四舍五入到最接近该值的倍数。

padding:一个布尔值,指示是否将序列填充到指定的最大长度。

以上就是训练的所有代码准备,下面就是训练了

 trainer = transformers.Trainer(
     model=model,
     train_dataset=train_data,
     eval_dataset=val_data,
     args=training_arguments,
     data_collator=data_collator
 )
 model.config.use_cache = False
 old_state_dict = model.state_dict
 model.state_dict = (
     lambda self, *_, **__: get_peft_model_state_dict(
         self, old_state_dict()
     )
 ).__get__(model, type(model))
 
 model = torch.compile(model)
 
 trainer.train()
 model.save_pretrained(OUTPUT_DIR)

在实例化训练器之后,代码在模型的配置中将use_cache设置为False,并使用get_peft_model_state_dict()函数为模型创建一个state_dict,该函数为使用低精度算法进行训练的模型做准备。

然后在模型上调用torch.compile()函数,该函数编译模型的计算图并准备使用PyTorch 2进行训练。

训练过程在A100上持续了大约2个小时。我们看一下Tensorboard上的结果:

训练损失和评估损失呈稳步下降趋势。看来我们的微调是有效的。

如果你想将模型上传到Hugging Face上,可以使用下面代码,

 from huggingface_hub import notebook_login
 
 notebook_login()
 model.push_to_hub("curiousily/alpaca-bitcoin-tweets-sentiment", use_auth_token=True)

推理

我们可以使用generate.py脚本来测试模型:

 !git clone https://github.com/tloen/alpaca-lora.git
 %cd alpaca-lora
 !git checkout a48d947

我们的脚本启动的gradio应用程序

 !python generate.py \
     --load_8bit \
     --base_model 'decapoda-research/llama-7b-hf' \
     --lora_weights 'curiousily/alpaca-bitcoin-tweets-sentiment' \
     --share_gradio

简单的界面如下:

总结

我们已经成功地使用LoRa方法对Llama 模型进行了微调,还演示了如何在Gradio应用程序中使用它。

如果你对本文感兴趣,请看原文:

https://avoid.overfit.cn/post/34b6eaf7097a4929b9aab7809f3cfeaa

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/761815.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

分型(一)

一点想法&#xff1a; 坐标系本来就是人头脑当中的东西&#xff0c;可以根据需要动态调整&#xff0c;图像处理中&#xff0c;用动态的波看静态的图像&#xff0c;进而找出图像的特征&#xff0c;分型是不是把动态的过程呈现出静态的特征呢&#xff1f; 芒德布罗集&#xff1…

Python(十五)数据类型——bool型

❤️ 专栏简介&#xff1a;本专栏记录了我个人从零开始学习Python编程的过程。在这个专栏中&#xff0c;我将分享我在学习Python的过程中的学习笔记、学习路线以及各个知识点。 ☀️ 专栏适用人群 &#xff1a;本专栏适用于希望学习Python编程的初学者和有一定编程基础的人。无…

【动手学深度学习】--08.实战:Kaggle房价预测

文章目录 实战&#xff1a;Kaggle房价预测1.访问和读取数据集2.数据预处理2.1标准化数据2.2离散数据处理 3.训练4.K折交叉验证5.模型选择 实战&#xff1a;Kaggle房价预测 1.访问和读取数据集 # 如果没有安装pandas&#xff0c;请取消下一行的注释 # !pip install pandas%matp…

华为云CodeArts产品体验的心得体会及想法

文章目录 前言CodeArts 的产品优势一站式软件开发生产线研发安全Built-In华为多年研发实践能力及规范外溢高质高效敏捷交付 功能特性说明体验感受问题描述完结 前言 华为云作为一家全球领先的云计算服务提供商&#xff0c;致力于为企业和个人用户提供高效、安全、可靠的云服务。…

基于SpringBoot+vue的在线答疑系统设计与实现

博主介绍&#xff1a; 大家好&#xff0c;我是一名在Java圈混迹十余年的程序员&#xff0c;精通Java编程语言&#xff0c;同时也熟练掌握微信小程序、Python和Android等技术&#xff0c;能够为大家提供全方位的技术支持和交流。 我擅长在JavaWeb、SSH、SSM、SpringBoot等框架…

《实战AI低代码》:普元智能化低代码开发平台发布,结合专有模型大幅提升软件生产力

在7月6日举办的“低代码+AI”产品战略发布会上,普元智能化低代码开发平台正式发布。该平台融合了普元自主研发的专有模型,同时也接入了多款AI大模型的功能。它提供了一系列低代码产品,包括中间件、业务分析、应用开发、数据中台和业务流程自动化等,旨在简化企业的复杂软件生…

Linux网络基础 — 传输层

目录 再谈端口号 端口号范围划分 认识知名端口号 netstat pidof UDP协议 UDP协议端格式 UDP的特点 面向数据报 UDP的缓冲区 UDP使用注意事项 基于UDP的应用层协议 TCP协议 TCP协议段格式 几个问题&#xff1a; 确认应答(ACK)机制 6个标记位 超时重传机制 连接…

rust 引用怎么用

本来好好的引用&#xff0c;被 rust 玩坏了&#xff0c;搞得自己都不会使用引用了&#xff0c;我们还是从简单的例子入手&#xff0c;来探索使用引用可能遇到额问题。 下面的示例代码编译不通过&#xff0c;在 s1 赋值给变量 s2 的过程中&#xff0c;字符串 neojos 值的所有权…

webGPT浏览器安装

edge点击“获取” https://microsoftedge.microsoft.com/addons/detail/wetab%E5%85%8D%E8%B4%B9chatgpt%E6%96%B0%E6%A0%87%E7%AD%BE%E9%A1%B5/bpelnogcookhocnaokfpoeinibimbeff?utm_sourceSteamDB # 其他浏览器安装教程如下&#xff1a; https://www.wetab.link/

springcloud gateway 介绍与使用

定义 该项目提供了一个用于在 Spring WebFlux 之上构建 API 网关的库。Spring Cloud Gateway 旨在提供一种简单而有效的方法来路由到 API&#xff0c;并为它们提供横切关注点&#xff0c;例如&#xff1a;安全性、监控/指标和弹性 特征 Spring Cloud Gateway 特性&#xff1a;…

Smartbi 身份认证绕过漏洞

0x00 简介 Smartbi是广州思迈特软件有限公司旗下的商业智能BI和数据分析产品&#xff0c;致力于为企业客户提供一站式商业智能解决方案。 0x01 漏洞概述 Smartbi在安装时会内置三个用户&#xff08;public、service、system&#xff09;&#xff0c;在使用特定接口时&#x…

只会用插件可不行,这些前端动画技术同样值得收藏-JavaScript篇(上)

目录 前言 settimeout/setinterval requestAnimationFrame 基本用法 时间戳参数 帧数与时间戳计算 自动暂停 JS中的贝塞尔曲线 概念 公式 二次贝塞尔 三次贝塞尔 N次贝塞尔 贝塞尔曲线动画 动画类 在动画中使用贝塞尔 总结 相关代码&#xff1a; 贝塞尔曲线相…

【深度学习】:《PyTorch入门到项目实战》(十五):三大经典卷积神经网络架构:LeNet、AlexNet、VGG

【深度学习】&#xff1a;《PyTorch入门到项目实战》(十五)&#xff1a;三大经典卷积神经网络架构&#xff1a;LeNet、AlexNet、VGG&#xff08;代码实现及实际案例比较&#xff09; ✨本文收录于【深度学习】&#xff1a;《PyTorch入门到项目实战》专栏&#xff0c;此专栏主要…

数据结构与算法——顺序表的基本操作(C语言详解版)

顺序表插入元素 向已有顺序表中插入数据元素&#xff0c;根据插入位置的不同&#xff0c;可分为以下 3 种情况&#xff1a; 插入到顺序表的表头&#xff1b;在表的中间位置插入元素&#xff1b;尾随顺序表中已有元素&#xff0c;作为顺序表中的最后一个元素&#xff1b; 虽然…

GaussDB OLTP 云数据库配套工具DAS

目录 一 、前言 二、DAS的定义 1、DAS的定义 2、DAS功能特点 三、DAS应用场景 1、标准版 2、企业版 四、操作示例&#xff08;标准版&#xff09; 1、登录华为控制台登录&#xff0c;输入账号密码 2、新增数据库实例链接 3、新建对象 4、SQL操作 5、导入导出 五、…

OC多态性浅析

OC多态性浅析 小实验 假设有以下两个类classA与class B的声明与实现&#xff1a; /// classA.h#ifndef classA_h #define classA_h#import <Foundation/Foundation.h>interface classA : NSObject-(void) printVar;end#endif /* classA_h *//// classB.h#ifndef class…

微信小程序事件点击跳转页面的五种方法

第一种:标签 这是最常见的一种跳转方式,相当于html里的a标签 <navigator url"/pages/main/main"></navigator>第二种:wx.navigateTo({})方法 1.前端wxml <button bindtap"getCeshi" type"primary"> 测试按钮 </button>…

基于SpringBoot+vue的外卖点餐系统设计与实现

博主介绍&#xff1a; 大家好&#xff0c;我是一名在Java圈混迹十余年的程序员&#xff0c;精通Java编程语言&#xff0c;同时也熟练掌握微信小程序、Python和Android等技术&#xff0c;能够为大家提供全方位的技术支持和交流。 我擅长在JavaWeb、SSH、SSM、SpringBoot等框架…

【Android知识笔记】应用进程(二)

Service的启动原理 向AMS发送startService请求 startService时会首先拿到AMS的Binder代理对象,向AMS发起startService请求: AMS处理startService请求 接下来看AMS端处理应用的startService请求: 回忆一下应用进程启动流程: 接下来看如果Service所在应用进程没有启动的情…

GitLab 私有 Go Modules 的搭建配置

配置步骤 配置 GitLab Access Token 将 GitLab Access Token 写入到 ~/.netrc 文件里(也就是根目录下的 .netrc 文件),没有则新建,并写入内容:machine git.tenorshare.cn login ${USER_NAME} password ${GITLAB_ACCESS_TOKEN},其中 ${USER_NAME} 是GitLab 用户名,${GIT…