[大模型]# Yi-6B-Chat Lora 微调

news2025/1/11 10:12:11

Yi-6B-Chat Lora 微调

概述

本节我们介绍如何基于 transformers、peft 等框架,对 Yi-6B-Chat 模型进行 Lora 微调。Lora 是一种高效微调方法,深入了解其原理可参见博客:知乎|深入浅出Lora。

本节所讲述的代码脚本在同级目录 04-Yi-6B-Chat Lora 微调 下,运行该脚本来执行微调过程,但注意,本文代码未使用分布式框架,微调 Yi-6B-Chat 模型至少需要 20G 及以上的显存,且需要修改脚本文件中的模型路径和数据集路径。

环境配置

在完成基本环境配置和本地模型部署的情况下(本教程中使用的模型路径是 /root/autodl-tmp/01ai/Yi-6B-Chat ),你还需要安装一些第三方库,可以使用以下命令:

pip install transformers==4.35.2
pip install peft==0.4.0
pip install datasets==2.10.1
pip install accelerate==0.20.3
pip install tiktoken
pip install transformers_stream_generator

在本节教程里,我们将微调数据集放置在根目录 /dataset。

指令集构建

LLM 的微调一般指指令微调过程。所谓指令微调,是说我们使用的微调数据形如:

{
    "instrution":"回答以下用户问题,仅输出答案。",
    "input":"1+1等于几?",
    "output":"2"
}

其中,instruction 是用户指令,告知模型其需要完成的任务;input 是用户输入,是完成用户指令所必须的输入内容;output 是模型应该给出的输出。

即我们的核心训练目标是让模型具有理解并遵循用户指令的能力。因此,在指令集构建时,我们应针对我们的目标任务,针对性构建任务指令集。例如,在本节我们使用由笔者合作开源的 Chat-甄嬛 项目作为示例,我们的目标是构建一个能够模拟甄嬛对话风格的个性化 LLM,因此我们构造的指令形如:##

{
    "instruction": "现在你要扮演皇帝身边的女人--甄嬛",
    "input":"你是谁?",
    "output":"家父是大理寺少卿甄远道。"
}

我们所构造的全部指令数据集在根目录下。

数据格式化

Lora 训练的数据是需要经过格式化、编码之后再输入给模型进行训练的,如果是熟悉 Pytorch 模型训练流程的同学会知道,我们一般需要将输入文本编码为 input_ids,将输出文本编码为 labels,编码之后的结果都是多维的向量。我们首先定义一个预处理函数,这个函数用于对每一个样本,编码其输入、输出文本并返回一个编码后的字典:

def process_func(example):
    MAX_LENGTH = 384    # Llama分词器会将一个中文字切分为多个token,因此需要放开一些最大长度,保证数据的完整性
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer("\n".join(["<|im_start|>system", "现在你要扮演皇帝身边的女人--甄嬛.<|im_end|>" + "\n<|im_start|>user\n" + example["instruction"] + example["input"] + "<|im_end|>\n"]).strip(), add_special_tokens=False)  # add_special_tokens 不在开头加 special_tokens
    response = tokenizer("<|im_start|>assistant\n" + example["output"] + "<|im_end|>\n", add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]  # 因为eos token咱们也是要关注的,所以补充为1
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]  # Yi-6B的构造就是这样的
    if len(input_ids) > MAX_LENGTH:  # 做一个截断
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

然后加载我们的数据集,用上面定义的函数处理数据

# 将JSON文件转换为CSV文件
import pandas as pd
from datasets import Dataset
df = pd.read_json('/root/dataset/huanhuan.json')
ds = Dataset.from_pandas(df)

tokenized_id = ds.map(process_func, remove_columns=ds.column_names)

经过格式化的数据,也就是送入模型的每一条数据,都是一个字典,包含了 input_idsattention_masklabels 三个键值对,其中 input_ids 是输入文本的编码,attention_mask 是输入文本的 attention mask,labels 是输出文本的编码。decode之后应该是这样的:

<|im_start|>system
现在你要扮演皇帝身边的女人--甄嬛.<|im_end|>
<|im_start|>user
小姐,别的秀女都在求中选,唯有咱们小姐想被撂牌子,菩萨一定记得真真儿的——<|im_end|>
<|im_start|>assistant
嘘——都说许愿说破是不灵的。<|im_end|>
<|endoftext|>

我们可以输出一条文本观察一下:

print(tokenizer.decode(tokenized_id[0]['input_ids']))
print(tokenizer.decode(list(filter(lambda x: x != -100, tokenized_id[0]["labels"]))))

输出结果如下图所示:

在这里插入图片描述

加载tokenizer和半精度模型

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('01ai/Yi-6B-Chat', use_fast=False, trust_remote_code=True)

# 模型以半精度形式加载,如果你的显卡比较新的话,可以用torch.bfolat形式加载
model = AutoModelForCausalLM.from_pretrained('01ai/Yi-6B-Chat', trust_remote_code=True, torch_dtype=torch.half, device_map="auto")

定义LoraConfig

LoraConfig这个类中可以设置很多参数,但主要的参数没多少,简单讲一讲,感兴趣的同学可以直接看源码。

  • task_type:模型类型
  • target_modules:需要训练的模型层的名字,主要就是attention部分的层,不同的模型对应的层的名字不同,可以传入数组,也可以字符串,也可以正则表达式。
  • rlora的秩,具体可以看Lora原理
  • lora_alphaLora alaph,具体作用参见 Lora 原理
from peft import LoraConfig, TaskType, get_peft_model
config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, 
    target_modules=["q_attn", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    inference_mode=False, # 训练模式
    r=8, # Lora 秩
    lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理
    lora_dropout=0.1# Dropout 比例
)

训练模型

首先,使用get_peft_model函数将基础模型和peft_config包装起来,以创建PeftModel。要了解模型中可训练参数的数量,可以使用print_trainable_parameters方法。

model = get_peft_model(model, config)
model.enable_input_require_grads() # 开启梯度检查点时,要执行该方法
model.print_trainable_parameters()

接下来,我们自定义 TrainingArguments 参数

TrainingArguments这个类的源码也介绍了每个参数的具体作用,当然大家可以来自行探索,这里就简单说几个常用的。

  • output_dir:模型的输出路径
  • per_device_train_batch_size:顾名思义 batch_size
  • gradient_accumulation_steps: 梯度累加,如果你的显存比较小,那可以把 batch_size 设置小一点,梯度累加增大一些。
  • logging_steps:多少步,输出一次log
  • num_train_epochs:顾名思义 epoch
  • gradient_checkpointing:梯度检查,这个一旦开启,模型就必须执行model.enable_input_require_grads(),这个原理大家可以自行探索,这里就不细说了。
from transformers import DataCollatorForSeq2Seq, TrainingArguments, Trainer
args = TrainingArguments(
    output_dir="./output/Yi-6B",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    logging_steps=10,
    num_train_epochs=3,
    gradient_checkpointing=True,
    save_steps=100,
    learning_rate=1e-4,
    save_on_each_node=True
)

最后,使用Traniner训练模型

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_id,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)
trainer.train()

模型训练完成后,会输出如下图所示的信息:
在这里插入图片描述

模型推理

下载好的模型被保存在了 ./output/Yi-6B 目录下,如果想要从头加载微调好的模型,需要执行下面的代码

from transformers import AutoModelForSeq2SeqLM
from peft import PeftModel, PeftConfig

peft_model_id = "output/Yi-6B/checkpoint-600"  # 这里我训练出效果最好的一版是 checkpoint-600,所以调用了这个,大家可以根据自己情况选择
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, peft_model_id)

然后使用以下代码进行模型推理:

model.eval()
input = tokenizer("<|im_start|>system\n现在你要扮演皇帝身边的女人--甄嬛.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n".format("你是谁?", "").strip() + "\nassistant\n ", return_tensors="pt").to(model.device)

max_length = 512

outputs = model.generate(
    **input,
    max_length=max_length,
    eos_token_id=7,
    do_sample=True,
    repetition_penalty=1.3,
    no_repeat_ngram_size=5,
    temperature=0.1,
    top_k=40,
    top_p=0.8,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1591133.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【教学类-52-04】20240412动物数独(4宫格)空1-空15

作品展示 背景需求&#xff1a; 【教学类-52-03】20240412动物数独&#xff08;4宫格&#xff09;难度1-9 打印版-CSDN博客文章浏览阅读603次&#xff0c;点赞20次&#xff0c;收藏8次。【教学类-52-03】20240412动物数独&#xff08;4宫格&#xff09;难度1-9 打印版https://…

Razzashi Raptor

拉扎什迅猛龙 Razzashi Raptor 95000金&#xff08;游戏币&#xff09;比老虎便宜多了&#xff0c;捡漏啊 为啥我开团都不出&#xff0c;很生气&#xff0c;去打架&#xff01;&#xff01;

【题目】【信息安全管理与评估】2022年国赛高职组“信息安全管理与评估”赛项样题2

【题目】【信息安全管理与评估】2022年国赛高职组“信息安全管理与评估”赛项样题2 信息安全管理与评估 网络系统管理 网络搭建与应用 云计算 软件测试 移动应用开发 任务书&#xff0c;赛题&#xff0c;解析等资料&#xff0c;知识点培训服务 添加博主wx&#xff1a;liuliu548…

在线药房数据惨遭Ransomhub窃取,亚信安全发布《勒索家族和勒索事件监控报告》

本周态势快速感知 本周全球共监测到勒索事件119起&#xff0c;与上周相比勒索事件有所增长。 本周Blacksuit是影响最严重的勒索家族&#xff0c;Ransomhub和Blackbasta恶意家族紧随其后&#xff0c;从整体上看Lockbit3.0依旧是影响最严重的勒索家族&#xff0c;需要注意防范。…

《五》QListWidget列表框

QListWidgetQListWidget和QListWidgetItem QListWidget 是qt中的列表框控件&#xff0c;它用于显示多个列表项&#xff0c;列表项对应的类是QListWidgetItem. QListWidget列表框的创建 QListWidget 类的继承关系如下&#xff1a; QListWidget -> QListView -> QAbs…

SonarQube 9.9.4 LTS社区版安装

目标 安装个SonarQube社区版. 安装SonarQube9.9.4 LTS社区版 https://binaries.sonarsource.com/Distribution/sonarqube/sonarqube-9.9.4.87374.zip # 切换到安装目录 cd /opt # 下载安装包 sudo wget https://binaries.sonarsource.com/Distribution/sonarqube/sonarqube…

中国移动校园招聘相关笔试题整理

目录 公司文化文档更新时间 公司文化 24年的 改正地方&#xff1a; 世界电信日&#xff1a;1969年5月17日18年的企业文化题 改正地方&#xff1a; 中国移动企业发展定位是&#xff1a;世界一流企业 中国移动的企业文化体系主要由核心价值观、使命和愿景三部分构成。 核心价值观…

基于Springboot的自习室预订系统

基于SpringbootVue的自习室预订系统的设计与实现 开发语言&#xff1a;Java数据库&#xff1a;MySQL技术&#xff1a;SpringbootMybatis工具&#xff1a;IDEA、Maven、Navicat 系统展示 用户登录页 网站首页 公告信息 留言反馈 后台管理 学生信息管理 公告信息管理 留言…

004Node.js常用快捷键

1.常用的终端命令&#xff1a; &#xff08;1&#xff09;del 文件名&#xff1a; 删除文件 &#xff08;2&#xff09;ipconfig: 查看IP命令 &#xff08;3&#xff09;mkdir 目录名 &#xff1a;在当前目录新建指定目录 &#xff08;4&#xff09;rd 目录名&#xff1a;在当前…

面经:Cassandra分布式NoSQL数据库深度解读

作为一位热衷于分享技术知识的博主&#xff0c;我深知在当今大数据时代&#xff0c;掌握分布式数据库尤其是Apache Cassandra的原理与实践对于提升个人技能和应对面试挑战的重要性。本篇博客将从我的面试经验出发&#xff0c;结合对Cassandra核心特性的理解&#xff0c;深入探讨…

万兆以太网MAC设计(3)MAC_RX模块添加CRC

文章目录 前言一、并行CRC处理二、添加CRC处理的MAC_RX模块三、总结 前言 上文介绍的MAC_RX模块实现了接受字节对齐的功能&#xff0c;但是尾端存在4字节CRC校验未处理。 一、并行CRC处理 前面在千兆以太网里对CRC代码和使用进行了介绍&#xff0c;千兆里面数据是一个一个by…

c++24.4.13-const修饰指针

1、const修饰指针-常量指针 2、const修饰常量-指针常量 3、const既修饰指针又修饰常量 示例

使用yolov8实现自动车牌识别(教程+代码)

该项目利用了一个被标记为“YOLOv8”的目标检测模型&#xff0c;专门针对车牌识别任务进行训练和优化。整个系统通常分为以下几个核心步骤&#xff1a; 数据准备&#xff1a; 收集包含车牌的大量图片&#xff0c;并精确地标记车牌的位置和文本信息。数据集可能包含各种环境下的…

这家动画公司,女神表情灵动秒杀90%的国漫女角色!

当3D国漫市场逐渐加入“内卷”的行列&#xff0c;从大的底层创作引擎UE的运用迭代&#xff0c;到细节的人物动捕、面捕技术的实际结合&#xff0c;在这场内卷的百舸争流中&#xff0c;涌现出一家家风格各异的头部国漫制作公司&#xff1a;有整体偏写实风格的原力动画&#xff0…

Vue3——html-doc-ja(html导出为word的js库)

一、下载 官方地址 html-doc-js - npm npm install html-doc-js 二、使用方法 // 使用页面中引入 import exportWord from html-doc-js// 配置项以及实现下载方法 const wrap document.getElementById(test)const config {document:document, //默认当前文档的document…

C++类和对象中上篇

1.类的6个默认成员函数 如果一个类中什么成员都没有&#xff0c;那就简称他为空类。 空类中真的什么都没有吗&#xff1f;并不是&#xff0c;任何类在什么都不写时&#xff0c;编译器会自动生成以下6个默认成员函数。 默认成员函数&#xff1a;用户没有显式实现&#xff0c;…

二、Maven安装

Maven安装 一、Centos7.9安装1.下载2.安装3.设置国内镜像4.设置maven安装路径 一、Centos7.9安装 1.下载 第一种&#xff1a;官网下载最新版本&#xff1a;http://maven.apache.org/download.cgi第二种&#xff1a;其他版本下载&#xff1a;https://archive.apache.org/dist/…

Presto Player 2.0 – 引人入胜的视频播放列表

Presto Player 2.0 引入了一项令人惊叹的新功能&#xff1a;视频播放列表。 将其与类似 Netflix 的新体验相结合&#xff0c;您将发现一款流畅的视频播放器&#xff0c;其功能在市场上任何其他工具中都找不到。 让我们看看 Presto Player 2.0 如何将您的内容提升到新的参与度…

Python 正则表达式模块使用

目录 1、匹配单个字符 2、匹配多个字符 3、匹配开头结尾 4、匹配分组 说明&#xff1a;在Python中需要通过正则表达式对字符串进行匹配的时候&#xff0c;可以使用re模块 表达式&#xff1a;re.match(正则表达式&#xff0c; 要匹配的字符串) 有返回值说明匹配成功&#x…

服务器数据恢复—不同型号服务器RAID5数据恢复策略有何不同?

RAID5作为应用最广泛的raid阵列级别之一&#xff0c;在不同型号服务器中的RAID5出现故障后&#xff0c;处理方法也不同。 RAID5阵列级别是无独立校验磁盘的奇偶校验磁盘阵列&#xff0c;采用数据分块和独立存取技术&#xff0c;能在同一磁盘上并行处理多个访问请求&#xff0c;…