大模型的实践应用9-利用LoRA方法在单个GPU上微调FLAN-T5模型的过程讲解与实现

news2025/1/11 23:49:42

大家好,我是微学AI,今天给大家介绍一下大模型的实践应用9-利用LoRA方法在单个GPU上微调FLAN-T5模型的过程讲解与实现,文本我们将向您展示如何应用大型语言模型的低秩适应(LoRA)在单个GPU上微调FLAN-T5 XXL(110 亿个参数)模型。我们将利用Transformers、Accelerate和PEFT等第三方库。
在这里插入图片描述

1. 设置开发环境

这里我使用已设置好的 CUDA 驱动程序,安装PyTorch深度学习框架,还需要安装 安装 Hugging Face 中的第三方库。

# 安装 Hugging Face 中的第三方库
pip install "peft==0.2.0"
pip install "transformers==4.27.2" "datasets==2.9.0" "accelerate==0.17.1" "evaluate==0.4.0" "bitsandbytes==0.37.1" loralib --upgrade --quiet
# 安装所需的附加依赖项
pip install rouge-score tensorboard py7zr

2. 加载并准备数据集

我们将使用数据集:https://huggingface.co/datasets/samsum。
该数据集包含大约 16k 个带有摘要的对话信息数据。

{
  "id": "13818513",
  "summary": "Amanda baked cookies and will bring Jerry some tomorrow.",
  "dialogue": "Amanda: I baked cookies. Do you want some?\r\nJerry: Sure!\r\nAmanda: I'll bring you tomorrow :-)"
}

要加载samsum数据集,我们使用Datasets 库中的方法load_dataset()。

from datasets import load_dataset

# Load dataset from the hub
dataset = load_dataset("samsum")

print(f"Train dataset size: {len(dataset['train'])}")
print(f"Test dataset size: {len(dataset['test'])}")

为了训练我们的模型,我们需要将输入(文本)转换为Token ID。这是由Transformers Tokenizer 完成的。

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_id="google/flan-t5-xxl"

# Load tokenizer of FLAN-t5-XL
tokenizer = AutoTokenizer.from_pretrained(model_id)

在开始训练之前,我们需要预处理数据。抽象摘要是一项文本生成任务。我们的模型将采用文本作为输入并生成摘要作为输出。我们想要了解我们的输入和输出需要多长时间才能有效地批处理我们的数据。

from datasets import concatenate_datasets
import numpy as np
# 标记后的最大总输入序列长度。长于此的序列将被截断,短于此的序列将被填充。

tokenized_inputs = concatenate_datasets([dataset["train"], dataset["test"]]).map(lambda x: tokenizer(x["dialogue"], truncation=True), batched=True, remove_columns=["dialogue", "summary"])
input_lenghts = [len(x) for x in tokenized_inputs["input_ids"]]

max_source_length = int(np.percentile(input_lenghts, 85))
print(f"Max source length: {max_source_length}")

tokenized_targets = concatenate_datasets([dataset["train"], dataset["test"]]).map(lambda x: tokenizer(x["summary"], truncation=True), batched=True, remove_columns=["dialogue", "summary"])
target_lenghts = [len(x) for x in tokenized_targets["input_ids"]]

max_target_length = int(np.percentile(target_lenghts, 90))
print(f"Max target length: {max_target_length}")

我们在训练之前预处理数据集并将其保存到磁盘。您可以在本地计算机或 CPU 上运行此步骤并将其上传到Hugging Face Hub。

def preprocess_function(sample,padding="max_length"):
    # add prefix to the input for t5
    inputs = ["summarize: " + item for item in sample["dialogue"]]

    # tokenize inputs
    model_inputs = tokenizer(inputs, max_length=max_source_length, padding=padding, truncation=True)

    # Tokenize targets with the `text_target` keyword argument
    labels = tokenizer(text_target=sample["summary"], max_length=max_target_length, padding=padding, truncation=True)

    # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
    # padding in the loss.
    if padding == "max_length":
        labels["input_ids"] = [
            [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
        ]

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=["dialogue", "summary", "id"])
print(f"Keys of tokenized dataset: {list(tokenized_dataset['train'].features)}")

# save datasets to disk for later easy loading
tokenized_dataset["train"].save_to_disk("data/train")
tokenized_dataset["test"].save_to_disk("data/eval")

3. 使用 LoRA 和 bnb int-8 微调 T5模型

LoRA 技术在上一节课我已经介绍了,可以看《大模型的实践应用8-利用PEFT和LoRa技术微调大模型(LLM)的原理介绍与指南》这篇文章。

在这里插入图片描述

除了 LoRA 技术之外,我们还将使用bitanbytes LLM.int8()将冻结的 LLM 量化为 int8。这使我们能够将 FLAN-T5 XXL 所需的内存减少约 4 倍。

我们训练的第一步是加载模型。我们将使用philschmid/flan-t5-xxl-sharded-fp16 ,它是google/flan-t5-xxl的分片版本。分片将帮助我们在加载模型时不会耗尽内存。

from transformers import AutoModelForSeq2SeqLM
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType

model_id = "philschmid/flan-t5-xxl-sharded-fp16"

# 加载模型
model = AutoModelForSeq2SeqLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")

# 定义 LoRA配置
lora_config = LoraConfig(
 r=16,
 lora_alpha=32,
 target_modules=["q", "v"],
 lora_dropout=0.05,
 bias="none",
 task_type=TaskType.SEQ_2_SEQ_LM
)
# prepare int-8 model for training
model = prepare_model_for_int8_training(model)

# add LoRA adaptor
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()


trainable params: 18874368 || all params: 11154206720 || trainable%: 0.16921300163961817

正如你所看到的,这里我们只训练了模型参数的0.16%!这种巨大的内存增益将使我们能够在没有内存问题的情况下微调模型。

接下来是创建一个DataCollator负责填充我们的输入和标签的对象。我们将使用Transformers 库中的DataCollatorForSeq2Seq 。

from transformers import DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# we want to ignore tokenizer pad token in the loss
label_pad_token_id = -100
# Data collator
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=label_pad_token_id,
    pad_to_multiple_of=8
)

output_dir="lora-flan-t5-xxl"

# Define training args
training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
	auto_find_batch_size=True,
    learning_rate=1e-3, # higher learning rate
    num_train_epochs=5,
    logging_dir=f"{output_dir}/logs",
    logging_strategy="steps",
    logging_steps=500,
    save_strategy="no",
    report_to="tensorboard",
)

# Create Trainer instance
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
)
model.config.use_cache = False

现在让我们训练我们的模型并运行下面的代码。对于T5模型,出于稳定性目的保留了一些层float32。

# train model
trainer.train()

我们可以保存模型以用于推理和评估。我们暂时将其保存到磁盘,但你也可以使用该方法将其上传到Hugging Face。

peft_model_id="results"
trainer.model.save_pretrained(peft_model_id)
tokenizer.save_pretrained(peft_model_id)

保存模型以用于推理和评估。我们暂时将其保存到磁盘,但你也可以使用该方法将其上传到Hugging Face。

peft_model_id="results"
trainer.model.save_pretrained(peft_model_id)
tokenizer.save_pretrained(peft_model_id)

最后,我们的 LoRA 检查点只有 84MB大小,他包含了 samsum 的所有学习知识。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1264869.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java 基础学习(四)操作数组、软件开发管理

1 操作数组 1.1.1 System.arraycopy 方法用于数组复制 当需要将一个数组的元素复制到另一个数组中时,可以使用System.arraycopy方法。它提供了一种高效的方式来复制数组的内容,避免了逐个元素赋值的繁琐过程。相对于使用循环逐个元素赋值的方式&#x…

源码剖析 Spring Security 的实现原理

Spring Security 是一个轻量级的安全框架,可以和 Spring 项目很好地集成,提供了丰富的身份认证和授权相关的功能,而且还能防止一些常见的网络攻击。我在工作中有很多项目都使用了 Spring Security 框架,但基本上都是浅尝辄止&…

443. 压缩字符串

这篇文章会收录到 : 算法通关村第十二关-黄金挑战字符串冲刺题-CSDN博客 压缩字符串 描述 : 给你一个字符数组 chars ,请使用下述算法压缩: 从一个空字符串 s 开始。对于 chars 中的每组 连续重复字符 : 如果这一组长度为 1 ,…

华清远见嵌入式学习——C++——作业一

作业要求&#xff1a; 代码&#xff1a; #include <iostream>using namespace std;int main() {string str;cout << "请输入一个字符串&#xff1a;" << endl;getline(cin,str);int dx0,xx0,sz0,kg0,qt0;int len str.size() 1;for(int i0;i<l…

【读论文】【泛读】S-NERF: NEURAL RADIANCE FIELDS FOR STREET VIEWS

文章目录 0. Abstract1. Introduction2. Related work3. Methods-NERF FOR STREET VIEWS3.1 CAMERA POSE PROCESSING3.2 REPRESENTATION OF STREET SCENES3.3 DEPTH SUPERVISION3.4 Loss function 4. EXPERIMENTS5. ConclusionReference 0. Abstract Problem introduction&…

小狐狸ChatGPT付费创作系统V2.3.4独立版 +WEB端+ H5端最新去弹窗授权

ChatGPT付费创作系统V2.3.4版本优化了很多细节&#xff0c;如果使用着2.2.9版本建议没升级的必要。该版本为编译版无开源&#xff0c;2.3.X版本开始官方植入了更多的后门和更隐性的弹窗代码&#xff0c;后门及弹窗处理起来更麻烦。特别针对后台弹窗网址、暗链后门网址全部进行了…

UDS 相关时间参数

文章目录 UDS 全部时间参数UDS 应用层诊断时间参数1、P2 Client P2 Server P2* Client P2* Server 图例2、S3 Client S3 Server 图例 UDS CNA-TP网络层时间参数1、N_As/N_Ar 图例2、N_Bs 图例3、 N_Br 图例4、N_Cs 图例N_Cr 图例 UDS 网络层流控制时间参数 UDS 全部时间参数 UD…

vue3+ts 全局函数和变量的使用

<template><div>{{ $env }}<br />{{ $filters.format("的飞机") }}</div> </template><script setup lang"ts"> import { getCurrentInstance } from "vue"; const app getCurrentInstance(); console.log…

CVE-2020-11651(SaltStack认证绕过)漏洞复现

简介 SaltStack是使用Python开发的一个服务器基础架构集中化管理平台,底层采用动态的连接总线,使其可以用于编配,远程执行, 配置管理等等。 Salt非常容易设置和维护,而不用考虑项目的大小。从数量可观的本地网络系统,到跨数据中心的互联网部署,Salt设计为在任意数量的…

matlab频谱合成音乐《追光者》

选择你喜欢的一首钢琴曲&#xff0c;下载并分析曲谱&#xff0c;用matlab工具用频谱合成方法完成这首曲子的音乐合成。 前言&#xff1a;此文章为个人使用Matlab合成一首《追光者》音乐&#xff0c;且带混响和声效果 文章目录 一.题目二.要求三.课程设计目的四.概要设计五.详细…

Django项目部署本地windows IIS(详细版)和static文件设置(页面样式正常显示)

目录 必要条件&#xff1a; 一、下载并启用wfastcgi 二、window安装 IIS功能 三、IIS管理器中添加网站 1、复制项目 2、复制wfastcgi.py文件 3、创建文件web.config 4、添加网站&#xff0c;填写信息 5、启动fastcgi程序 6、修改进程标识 四、static文件设置和正确显…

吉利展厅 | 透明OLED拼接2x2:科技与艺术的完美融合

产品&#xff1a;4块55寸OLED透明拼接屏 项目地点&#xff1a;南宁 项目时间&#xff1a;2023年11月 应用场景&#xff1a;吉利展厅 在2023年11月的南宁&#xff0c;吉利展厅以其独特的展示设计吸引了众多参观者的目光。其中最引人注目的亮点是展厅中央一个由四块55寸OLED透…

pandas教程:USDA Food Database USDA食品数据库

文章目录 14.4 USDA Food Database&#xff08;美国农业部食品数据库&#xff09; 14.4 USDA Food Database&#xff08;美国农业部食品数据库&#xff09; 这个数据是关于食物营养成分的。存储格式是JSON&#xff0c;看起来像这样&#xff1a; {"id": 21441, &quo…

4、stable diffusion

github 安装anaconda环境 conda env create -f environment.yaml conda activate ldm安装依赖 conda install pytorch1.12.1 torchvision0.13.1 torchaudio0.12.1 cudatoolkit11.3 -c pytorch pip install transformers4.19.2 diffusers invisible-watermark pip install -e…

快速筛出EXCEL行中的重复项

比如A列是一些恶意IP需要导入防火墙&#xff0c;但包括一些重复项&#xff0c;为不产生错误&#xff0c;需要把重复项筛出来&#xff1a; 1、给A列排序&#xff0c;让重复项的内容排在相邻的行 2、在B列中写一个条件函数&#xff1a;IF(A1A2,1,0)&#xff0c;然后下拉至行尾完成…

2023-简单点-机器学习中常用的特殊函数,激活函数[sigmoid tanh ]

机器学习中的特殊函数 Sigmoidsoftplus函数tanhReLu(x)Leaky-ReluELUSiLu/ SwishMish伽玛函数beta函数Ref Sigmoid 值域: 【0,1】 定义域&#xff1a;【负无穷,正无穷】 特殊点记忆&#xff1a; 经过 [0 , 0.5] 关键点[0,0.5]处的导数是 0.025 相关导数&#xff1a; softplu…

群晖NAS配置之自有服务器frp实现内网穿透

什么是frp frp 是一个专注于内网穿透的高性能的反向代理应用&#xff0c;支持 TCP、UDP、HTTP、HTTPS 等多种协议&#xff0c;且支持 P2P 通信。可以将内网服务以安全、便捷的方式通过具有公网 IP 节点的中转暴露到公网。今天跟大家分享一下frp实现内网穿透 为什么使用 frp &a…

selenium 工具 的基本使用

公司每天要做工作汇报&#xff0c;汇报使用的网页版&#xff0c; 所以又想起 selenium 这个老朋友了。 再次上手&#xff0c;发现很多接口都变了&#xff0c; 怎么说呢&#xff0c; 应该是易用性更强了&#xff0c; 不过还是得重新看看&#xff0c; 我这里是python3。 pip安装…

Blender动画导入Three.js

你是否在把 Blender 动画导入你的 ThreeJS 游戏(或项目)中工作时遇到问题? 您的 .glb (glTF) 文件是否正在加载,但没有显示任何内容? 你的骨骼没有正确克隆吗? 如果是这样,请阅读我如何使用 SkeletonUtils.js 解决此问题 1、前提条件 你正在使用 Blender 3.1+(此版本…

微服务--03--OpenFeign 实现远程调用 (负载均衡组件SpringCloudLoadBalancer)

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 OpenFeign其作用就是基于SpringMVC的常见注解&#xff0c;帮我们优雅的实现http请求的发送。 RestTemplate实现了服务的远程调用 OpenFeign快速入门负载均衡组件Spr…