lora微调过程

news2024/11/24 9:01:33
import os
import pickle
from transformers import AutoModelForCausalLM
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType


device = "cuda:0"

#1.创建lora微调基本的配置
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)

 2.通过调用get_peft_model方法包装基础的Transformer模型

#通过调用get_peft_model方法包装基础的Transformer模型
model = AutoModelForCausalLM("/root/paddlejob/workspace/llama-2-7b-chat")
model = get_peft_model(model, peft_config)

下面是lora微调的模型结构,可以看到多了两个矩阵,一个降维一个升维 

3.训练

# optimizer and lr scheduler
'''len(train_dataloader) 是训练数据集中的批次数量,num_epochs 是训练过程中的迭代次数。因此,len(train_dataloader) * num_epochs 表示整个训练过程中的总迭代次数,即总共要遍历训练数据集的批次数'''
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) * num_epochs),)

#training and evaluation
model = model.to('cuda:0')
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for step, batch in enumerate(train_dataloader):
        batch = {k: v.to('cuda:0') for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        total_loss = total_loss + loss.detach().float()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

    model.eval()
    eval_loss = 0
    eval_preds = []
   for step, batch in enumerate(eval_dataloader):
        batch = {k: v.to('cuda:0') for k, v in bacth.items()}
        with torch.nograd():
            outputs = model(**bacth)
        loss = outputs.loss
        eval_loss = eval_loss + loss.detach().float()
        eval_preds.extend(
            tokenizer.bacth_decode(torch.argmax(outputs.logits, -1), skip_special_tokens=True)
            )
    
    eval_epoch_loss = eval_loss / len(eval_dataloader)
    eval_ppl = torch.exp(eval_epoch_loss)
    train_epoch_loss = total_loss / len(train_dataloader)
    train_ppl = torch.exp(train_epoch_loss)
        

'''
在 Python 中,** 运算符用于将字典解包为关键字参数传递给函数或方法。在 PyTorch 中,model(**batch) 中的 **batch 将字典 batch 中的键值对作为关键字参数传递给模型的方法(通常是前向传播方法)。

具体来说,**batch 将字典 batch 中的每个键值对解包为一组关键字参数。例如,如果 batch 字典包含键值对 {'input_ids': tensor1, 'attention_mask': tensor2},那么 model(**batch) 实际上就等价于 model(input_ids=tensor1, attention_mask=tensor2)。

这种方式可以方便地将字典中的数据传递给函数或方法,并且使代码更加简洁和易读。
'''

'''
loss.detach().float() 的作用是将计算图中的 loss 张量分离出来并转换为浮点数类型。

具体来说:

loss.detach() 会创建一个新的张量,其值与 loss 相同,但不再跟踪梯度信息。这样做是因为在训练过程中,我们通常只需要保存当前步骤的损失值,而不需要其相关的计算图和梯度信息。
.float() 将张量转换为浮点数类型。这是因为通常情况下,损失值是作为浮点数来计算和累加的。
所以,total_loss += loss.detach().float() 的作用就是将当前步骤的损失值添加到总损失值中,保证总损失值是一个浮点数。
'''

'''
在使用 PyTorch 进行梯度下降优化时,optimizer.zero_grad() 的作用是将模型参数的梯度归零,以便进行新一轮的梯度计算和更新。这是因为在 PyTorch 中,每次调用 .backward() 方法都会累积梯度,而不是覆盖之前的梯度。因此,在每次迭代更新参数之前,需要先将之前的梯度清零,以免影响当前迭代的梯度计算。

简而言之,optimizer.zero_grad() 用于初始化梯度,确保每次迭代都是基于当前 batch 的梯度计算和参数更新,而不会受到之前迭代的影响。
'''

'''

outputs.logits 是模型生成的原始输出,通常是一个三维张量,其中包含了模型对于每个词汇的得分(未经过 softmax 处理)。在语言模型中,这个张量的维度通常是 (batch_size, sequence_length, vocab_size),其中 batch_size 表示批量大小,sequence_length 表示每个序列的长度,vocab_size 表示词汇表的大小。
在生成文本任务中,outputs.logits 的每个元素表示模型在当前位置生成每个词汇的得分。通常,需要对这些得分进行 softmax 处理以获得每个词汇的概率分布,然后根据概率分布进行采样或选择最高概率的词汇作为模型生成的下一个词。

torch.argmax 是 PyTorch 库中的一个函数,用于返回张量中指定维度上的最大值的索引。具体而言,对于一个输入张量,torch.argmax(input, dim=None, keepdim=False) 函数将返回指定维度 dim 上最大值的索引。如果不指定 dim,则默认返回整个张量中最大值的索引。
例如,对于一个形状为 (batch_size, seq_length, vocab_size) 的张量,torch.argmax(outputs.logits, -1) 将返回在 vocab_size 维度上每个位置上的最大值对应的索引,即得分最高的词的索引。

这行代码的作用是将模型输出的 logits (对应每个词的得分)经过 torch.argmax 函数找到得分最高的词的索引,然后使用 tokenizer 对这些索引进行解码,将索引转换为对应的词,并通过 skip_special_tokens=True 参数去除特殊标记(如 [CLS], [SEP] 等)。最终得到的是模型生成的文本内容。
'''

 4.模型保存

#save model
peft_model_id = f"{model_name_path}_{peft_config.peft_type}_{peft_config.peft.task_type}"
model.save_pretrained(peft_model_id)

5.模型训练的其余部分无需更改,当模型训练完成后,保存高效微调的模型权重部分以供模型推理 

#加载微调后的权重
from peft import PeftModel, PeftConfig

config = PeftConfig.from_pretrained(peft_model_id)
##加载基础模型
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
##加载peft模型
model = PeftModel.from_pretrained(model, peft_model_id)

##加载tokenizer
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
tokennizer.pad_token = tokenizer.eos_token

6.加载微调后的权重文件,并进行推理 

#利用微调后的模型进行推理
##tokenizer编码
inputs = tokenizer(f'{text_column} : {dataset["test"][i]["Tweet text"]} Label : ', return_tensors="pt")

##模型推理
outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=10,
    eos_token_id=3
)

##tokenizer解码
print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1580901.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CAN和LIN的DB9接口定义

文章目录 前言一、DB9实物及引脚二、LIN DB9三、CAN DB9总结前言 在日常汽车总线测试中,最主要的通信网络就是CAN网络,小伙伴们在测试时,经常会遇到使用DB9插头来测试、录取CAN总线报文,但是DB9插头内有9个插针,哪2个才是CAN-H和CAN-L呢? 一、DB9实物及引脚 DB9 接口是…

杨辉三角形(蓝桥杯,acwing)

题目描述: 下面的图形是著名的杨辉三角形: 如果我们按从上到下、从左到右的顺序把所有数排成一列,可以得到如下数列: 1, 1, 1, 1, 2, 1, 1, 3, 3, 1, 1, 4, 6, 4, 1, ... 给定一个正整数 N,请你输出数列中第一次出现…

实验:基于Red Hat Enterprise Linux系统建立逻辑卷并进行划分

目录 一. 实验目的 二. 实验内容 三. 实验设计描述及实验结果 1. 为虚拟机添加三块大小为5GB的磁盘nvme0n2 nvme0n3 nvme0n4 2. 将三块硬盘转换为物理卷,并将nvme0n2 nvme0n3两pv建立成名为"自己名字_vg“的卷组,并将nvme0n4扩展进该卷组。 LVM管…

os.listdir()bug总结

今天测试出一个神奇的bug,算是教训吧,找了两天不知道问题在哪,最后才发现问题出现在这 原始文件夹显示 os.listdir()结果乱序 import os base_path "./file/"files os.listdir(base_path)print(files)问题原因 解决办法(排序) …

【论文解读】大模型事实性调查(上)

一、简要介绍 本调查探讨了大型语言模型(llm)中的事实性的关键问题。随着llm在不同领域的应用,其输出的可靠性和准确性变得至关重要。论文将“事实性问题”定义为llm产生与既定事实不一致的内容的概率。论文首先深入研究了这些不准确性的含义…

IO-DAY8

使用消息队列去实现2个终端之间的互相聊天 要求:千万不要做出来2个终端之间的消息发送是读一写的&#xff0c;一定要能够做到&#xff0c;一个终端发送n条消息&#xff0c;另一个终端一条消息都不回复 A终端&#xff1a; #include<myhead.h> typedef struct msgbuf {lon…

B02、执行引擎-5

1、前言 1.1、什么是机器码 各种用二进制编码方式表示的指令&#xff0c;叫做机器指令码。开始&#xff0c;人们就用它采编写程序&#xff0c;这就是机器语言。机器语言虽然能够被计算机理解和接受&#xff0c;但和人们的语言差别太大&#xff0c;不易被人们理解和记忆&#x…

基于SSM框架实现的在线心理评测与咨询系统(技术栈 spring+springmvc+mybatis+jsp+jquery+css)

一、项目简介 本项目是一套基于SSM框架实现的在线心理评测与咨询系统&#xff0c;主要针对计算机相关专业的正在做毕设的学生与需要项目实战练习的Java学习者。 包含&#xff1a;项目源码、数据库脚本等&#xff0c;该项目附带全部源码可作为毕设使用。 项目都经过严格调试&am…

小白学Java成长日记特别篇

晚上好&#xff0c;各位小伙伴。今天给大家带来的是Java的输出补充篇&#xff0c;前两篇说了输出和输入的大概&#xff0c;但我没有详细讲它俩&#xff0c;因此这篇文章来详细的聊一聊它俩。那么废话不多说&#xff0c;我们赶紧进入正题。 首先讲一讲这个Java的输出吧。 输出格…

使用MQTT.fx接入新版ONENet(24.4.8)

新版ONENet使用MQTT.fx 模拟接入 目录 新版ONENet使用MQTT.fx 模拟接入开始前的准备创建产品设备获取关键参数 计算签名使用MQTT.fx连接服务器数据流准备与上传数据流准备数据发送与接收 开始前的准备 创建产品 设备下载Token签名工具生成签名 创建产品设备 根据以下内容填写…

windows组播发不出去解决办法

由于开启了虚拟网卡&#xff0c;安装VMWare虚拟化软件&#xff0c;可能会通过虚拟网卡发送组播&#xff0c;需要禁用虚拟化网卡。

Linux网络名称空间的抽象设计以及借鉴意义

Linux作为一个强大的开源操作系统&#x1f427;&#xff0c;其网络虚拟化技术中的核心组件——网络名称空间&#xff08;Network Namespace&#xff09;&#xff0c;是对网络资源的一种高度抽象。网络名称空间允许系统内部存在多个隔离的网络环境&#xff0c;每个环境都有自己的…

数字图像处理与交叉学科中名词的拧巴

特征提取 图像处理——对图像、目标或特征点进行定量描述的方法及过程。 模式识别——对原特征进行特征变换&#xff0c;从高维空间到低维空间映射。 特征向量 模式识别、图像处理——一个观测包括多个变量&#xff0c;样本的多个特征组成特征向量。 线性代数——特征值对应的…

随机过程-BS定理

随机偏微分方程相比普通偏微分方程具有额外的随机项&#xff0c;反映了其描述的现象具有随机性质

请求转发和请求重定向的区别

请求转发(Forward)和请求重定向(Redirect)虽然都是 HTTP 服务器&#xff0c;处理客户端请求时进行(页面)跳转的实现方式&#xff0c;但是二者有以下 5 点不同: 1. 定义不同。 2. 跳转方不同。 3. 数据共享不同。 4.最终 URL 地址不同。 5.代码实现不同。 具体内容如下&…

第1章 MySQL概述

文章目录 第1章 MySQL概述1.1 前言1.2 MySQL安装1.3 常见的指令&#xff08;1&#xff09;MySQL的启动和关闭语句&#xff08;2&#xff09;MySQL的登录语句&#xff08;3&#xff09;MySQL的退出语句&#xff08;4&#xff09;查看MySQL的版本号&#xff08;5&#xff09;查看…

lgwr超时如何判断存储还是cpu问题?(等待事件各种类型和说明及相关查询)

通过awr报告看&#xff1a; 分析&#xff1a; log file parallel write平均等待8毫秒 log file sync平均等待402毫秒 排查&#xff1a; log file sync parallel write lgwr cpu log file parallel write等待少说明存储不慢。 所以&#xff1a;log file sync等待长是因为…

Redis系列之主从复制集群搭建

在上一篇博客&#xff0c;我们已经知道怎么搭建一个redis单机版&#xff0c;这篇博客基于之前的基础&#xff0c;来搭建一个redis主从同步&#xff0c;本博客框架是一主二从&#xff0c;一个主节点&#xff0c;其它两个从节点 实验环境 CentOS7Xshell6XFtp6Redis6.2.2 主从关…

强化学习MPC——(二)

本篇主要介绍马尔科夫决策&#xff08;MDP&#xff09;过程&#xff0c;在介绍MDP之前&#xff0c;还需要对MP&#xff0c;MRP过程进行分析。 什么是马尔科夫&#xff0c;说白了就是带遗忘性质&#xff0c;下一个状态S_t1仅与当前状态有关&#xff0c;而与之前的状态无关。 为…

金山系雄风再显,雷军创建黑灯工厂,中文编程也迎来新突破

金山系被誉为互联网行业的黄埔军校&#xff0c;这一称谓绝非虚言。 求伯君、雷军、蒋涛、傅盛、王峰等金山系的前辈们&#xff0c;不仅是中国科技发展的领军人物&#xff0c;更是以他们的智慧和勇气&#xff0c;引领着国产科技的新方向&#xff0c;书写了一段段充满激情与传奇色…