【深度学习总结_03】使用弱智吧数据微调LLama3+自我认知训练

news2024/10/7 5:47:54

使用弱智吧数据微调LLama3+自我认知训练

  • 使用弱智吧数据微调LLama3+自我认知训练
    • 下载LLama3权重
    • 准备数据集
    • 克隆alpaca-lora仓库
    • 修改finetune.py代码
      • 修改LlamaTokenizer
      • 注释代码
      • 手动安装apex
    • 运行finetune.py
    • 运行generate.py文件
    • 导出Lora模型
    • 自我认知训练

使用弱智吧数据微调LLama3+自我认知训练

参考链接:alpaca-lora
本实验在趋动云进行,该平台注册做完任务可以免费送300元算力。
平台链接为:趋动云

下载LLama3权重

博主已经在趋动云上传了一份LLama3-8b-instruct的权重,用户可以在创建项目后并进入项目后点击模型进行修改。
在这里插入图片描述
我使用的镜像是pytorch 2.1.2。
在这里插入图片描述

准备数据集

我选择的是Better-Ruozhiba的数据,并对其进行了处理,构成如下形式的数据:

[
    {
        "instruction": "爸爸再婚,我是不是就有了个新娘?",
        "input": "",
        "output": "不是的,你有了一个继母。\"新娘\"是指新婚的女方,而你爸爸再婚,他的新婚妻子对你来说是继母。"
    }
]

这是alpaca-lora中输入到LLama的一种数据格式,含有instruction,input,output三个键。

自我认知数据集我选择的是ChatGLM的self_cognition.json,内容示例为:

[
  {
    "instruction": "你好",
    "input": "",
    "output": "您好,我是 <NAME>,一个由 <AUTHOR> 开发的 AI 助手,很高兴认识您。请问我能为您做些什么?"
  },
]

你可以根据自己的需要修改 <NAME><AUTHOR>

克隆alpaca-lora仓库

本次实践是基于alpaca-lora的代码开发的,它是对llama2的一个微调。

git clone https://github.com/tloen/alpaca-lora.git

修改finetune.py代码

首先,需要修改finetune.py的代码,因为原始的代码对最新的一些包不适配。

修改LlamaTokenizer

首先将代码中的LlamaTokenizer换成AutoTokenizer,因为LlamaTokenizer需要你提供的地址中含有tokenizer.model文件,而从官方下载的LLama3的权重里面是没有这个文件的。

如果你没有修改的话,将会报错:

return _sentencepiece.SentencePieceProcessor_LoadFromFile(self, arg)
TypeError: not a string

注释代码

将260行左右的如下代码注释掉:

old_state_dict = model.state_dict
model.state_dict = (
	lambda self, *_, **__: get_peft_model_state_dict(
             self, old_state_dict()
         )
     ).__get__(model, type(model))

如果没有注释的话,最后训练结束后会报错safetensors_rust.SafetensorError: Error while deserializing header: InvalidHeaderDeserialization,同时一些文件无法报错,导致你需要重新跑一次。

手动安装apex

如果你运行finetune.py出现报错:ImportError: cannot import name ‘UnencryptedCookieSessionFactoryConfig’ from ‘pyramid.session’ (unknown location),那么你就需要重新安装apex库,如果没有出现,就需要,手动安装的代码为:

git clone git://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-build-isolation --no-cache-dir ./

运行finetune.py

可以新建一个finetune.sh文件,内容为:

python finetune.py \
    --base_model='/gemini/pretrain' \
    --data_path='/gemini/code/datasets/ruozhiba_train_data.json' \
    --cutoff_len=512 \
    --num_epochs=10 \
    --learning_rate 1e-4 \
    --batch_size=32 \
    --micro_batch_size=8 \
    --val_set_size=100 \
    --group_by_length \
    --output_dir='./lora-ruozhi' \
    --lora_target_modules='[q_proj,k_proj,v_proj,o_proj]' \
    --lora_r=16 \
    --lora_alpha 16 \
    --lora_dropout 0.05 \
    --train_on_inputs False \
    --prompt_template_name "alpaca" \

重要参数解释如下:

  • base_model:你需要微调的模型地址,这里我填的的LLama3的本地权重地址,如果你的网速好的话,也可以填写huggingface的地址
  • data_path:数据集的位置,其中ruozhiba_train_data.json是弱智吧的数据
  • cutoff_len:处理文本的最大长度,超过这个会截断
  • num_epochs:训练的epoch数目
  • learning_rate:学习率
  • batch_size:这个不是加载数据集时的batch size大小,它决定的是梯度累积的间隔
gradient_accumulation_steps = batch_size // micro_batch_size
  • micro_batch_size:这个才是加载数据集时的batch size大小
  • val_set_size:验证集大小
  • lora_target_modules:我们使用的是Lora微调,这里指定需要微调的模块
  • lora_r:Lora微调中rank的大小
  • lora_alpha:Lora的参数
  • train_on_inputs:输入文本是否参与训练
  • prompt_template_name:模板的名词,在templates文件夹中有

其中模板的文件内容为:

{
    "description": "Template used by Alpaca-LoRA.",
    "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
    "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
    "response_split": "### Response:"    
}

在训练过程中,数据集里面的内容会填充到prompt_input中对应的内容,从而输入到模型进行训练。你也可以在templates下新建其他的模板文件,然后更改prompt_template_name参数即可。

除此之外,你可能还需要在finetune.py对应位置修改模型保存的间隔等参数,如下:

trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=80,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=True,
            logging_steps=10,
            optim="adamw_torch",
            evaluation_strategy="steps" if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=50 if val_set_size > 0 else None,
            save_steps=50,
            output_dir=output_dir,
            save_total_limit=3,
            load_best_model_at_end=True if val_set_size > 0 else False,
            ddp_find_unused_parameters=False if ddp else None,
            group_by_length=group_by_length,
            report_to="wandb" if use_wandb else None,
            run_name=wandb_run_name if use_wandb else None,
        ),
        data_collator=transformers.DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )

其中eval_steps表示验证的步数间隔,save_steps表示保存的步数间隔,warmup_steps表示学习率的warm up步数。

运行finetune.sh,如下:

bash finetune.sh

最后的运行结果为:
在这里插入图片描述
上面的450就算总共跑的steps。
最终的lora权重保存在lora-ruozhi文件夹下面,如下:
在这里插入图片描述

运行generate.py文件

得到微调模型后,如果你想看看它的效果,运行generate.py文件,它通过gradio构建可视化界面,修改前面一部分代码即可:

def main(
    load_8bit: bool = False,
    base_model: str = "",
    lora_weights: str = "/gemini/code/Project/alpaca-lora/lora-ruozhi",
    prompt_template: str = "",  # The prompt template to use, will default to alpaca.
    server_name: str = "0.0.0.0",  # Allows to listen on all interfaces by providing '0.
    share_gradio: bool = True,
):
    # 改成base model地址
    # base_model = base_model or os.environ.get("BASE_MODEL", "")
    base_model = "/gemini/pretrain"
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"

    prompter = Prompter(prompt_template)
    tokenizer = AutoTokenizer.from_pretrained(base_model)

主要修改lora_weightsbase_modellora_weights就是上面一步保存的地址。
除此之外,还要将代码中的LlamaTokenizer换成AutoTokenizer,不然会报错。
share_gradio我也修改了,让其能够生成一个共享链接,方便调试。

最后运行python generate.py,我选取了几个问题进行测试,都没有中套:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
而原始的LLama3的结果为:
在这里插入图片描述
上面这个问题它用英文回答的,不过并没有中套。

在这里插入图片描述
这个问题LLama3翻车了,没有反应过来,说明我们的微调还是有作用的。

导出Lora模型

运行python export_hf_checkpoint.pyexport_hf_checkpoint.py的需要修改的地方为:

# BASE_MODEL = os.environ.get("BASE_MODEL", None)
# 1.改成要微调的模型地址
BASE_MODEL = "/gemini/pretrain"

# 2.LlamaTokenizer换成AutoTokenizer,记得导入
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

# 3.修改为lora微调保存的地址,就是上面微调部分部分保存的地址
lora_model = PeftModel.from_pretrained(
    base_model,
    # 修改为lora微调保存的地址
    "lora-ruozhi",
    device_map={"": "cpu"},
    torch_dtype=torch.float16,
)

# 4. 改成微调模型的输出位置
## max_shard_size表示每个文件的最大内存,因此最终的文件中会有多个safetensor文件
LlamaForCausalLM.save_pretrained(
    base_model, 
    # 改成你要输出的位置
    "./saves/lora-ruozhi", 
    state_dict=deloreanized_sd, max_shard_size="2048MB"
)

最终输出的文件夹内容为:
在这里插入图片描述

自我认知训练

修改finetune.sh,将base_model改成导出Lora模型阶段的输出地址,data_path改成自我认知数据集的地址,具体如下:

python finetune.py \
    --base_model='/gemini/code/Project/alpaca-lora/saves/lora-ruozhi' \
    --data_path='/gemini/code/datasets/self_cognition.json' \
    --cutoff_len=512 \
    --num_epochs=10 \
    --learning_rate 1e-4 \
    --batch_size=32 \
    --micro_batch_size=8 \
    --val_set_size=100 \
    --group_by_length \
    --output_dir='./lora-ruozhi-self' \
    --lora_target_modules='[q_proj,k_proj,v_proj,o_proj]' \
    --lora_r=16 \
    --lora_alpha 16 \
    --lora_dropout 0.05 \
    --train_on_inputs False \
    --prompt_template_name "alpaca" \

注意根据你自己的情况修改output_dirfinetune.py中的保存间隔等参数。

然后运行python generate.py,此时文件前面部分的地址改成为:

def main(
    load_8bit: bool = False,
    base_model: str = "",
    lora_weights: str = "/gemini/code/Project/alpaca-lora/lora-ruozhi-self",
    prompt_template: str = "",  # The prompt template to use, will default to alpaca.
    server_name: str = "0.0.0.0",  # Allows to listen on all interfaces by providing '0.
    share_gradio: bool = True,
):
    # 改成base model地址
    # base_model = base_model or os.environ.get("BASE_MODEL", "")
    base_model = "/gemini/code/Project/alpaca-lora/saves/lora-ruozhi"
    assert (
        base_model
    ), "Please specify a --base_model, e.g. --base_model='huggyllama/llama-7b'"

    prompter = Prompter(prompt_template)
    tokenizer = AutoTokenizer.from_pretrained(base_model)

因此这次自我认知训练是在弱智吧训练的模型基础上微调的,因此base_model改成了导出的Lora模型。

模型的输出结果为:
在这里插入图片描述
在这里插入图片描述
使用弱智吧微调得到的能力也没有消失:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1864498.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

每日一学(1)

目录 1、ConCurrentHashMap为什么不允许key为null&#xff1f; 2、ThreadLocal会出现内存泄露吗&#xff1f; 3、AQS理解 4、lock 和 synchronized的区别 1、ConCurrentHashMap为什么不允许key为null&#xff1f; 底层 putVal方法 中 如果key || value为空 抛出…

Unity编辑器工具---版本控制与自动化打包工具

Unity - 特殊文件夹【作用与是否会被打包到build中】 Unity编辑器工具—版本控制与自动化打包工具&#xff1a; 面板显示&#xff1a;工具包含一个面板&#xff0c;用于展示软件的不同版本信息。版本信息&#xff1a;面板上显示主版本号、当前版本号和子版本号。版本控制功能…

Socket——向FTP服务器发送消息并获得响应

1、简介 Socket&#xff08;套接字&#xff09;是网络编程中用于描述IP地址和端口的一个抽象概念&#xff0c;通过它可以实现不同主机间的通信。套接字可以分为几种不同的类型&#xff0c;每种类型对应不同的协议和传输模式。 1.1、基本概念 IP地址&#xff1a;用于标识网络…

智能语音抽油烟机:置入WTK6900L离线语音识别芯片 掌控厨房新风尚

一、抽油烟机语音识别芯片开发背景 在繁忙的现代生活中&#xff0c;人们对于家居生活的便捷性和舒适性要求越来越高。传统的抽油烟机操作方式往往需要用户手动调节风速、开关等功能&#xff0c;不仅操作繁琐&#xff0c;而且在烹饪过程中容易分散注意力&#xff0c;增加安全隐…

梅雨季要祛湿,更要养阳气!用好这2招,补足阳气,祛湿排寒,助你体质越来越好~

梅雨季祛湿正当时&#xff0c;但是很多人疑惑不断祛湿&#xff0c;可为什么湿气一直源源不断呢&#xff1f;除了不良生活方式没有改变外&#xff0c;很多人祛湿只是解决“标”的问题&#xff0c;没有解决“本”的问题。 中医有一句话叫“无阳难以化湿”&#xff01; 我们可以简…

前端 CSS 经典:模拟 material 文本框

效果 思路 定义三个元素&#xff0c;文本框&#xff0c;下划线&#xff0c;占位文字。input 聚焦时通过 ~ 选中兄弟元素&#xff0c;利用 required 属性 css 中的 valid 验证&#xff0c;判断 input 中是否有输入。写入过渡效果。 实现代码 <!DOCTYPE html> <htm…

基于 JuiceFS 构建高校 AI 存储方案:高并发、系统稳定、运维简单

中山大学的 iSEE 实验室&#xff08;Intelligence Science and System) Lab&#xff09;在进行深度学习任务时&#xff0c;需要处理大量小文件读取。在高并发读写场景下&#xff0c;原先使用的 NFS 性能较低&#xff0c;常在高峰期导致数据节点卡死。此外&#xff0c;NFS 系统的…

203.回溯算法:N皇后(力扣)

class Solution { public:vector<vector<string>> result; // 用于存储所有合法的 N 皇后放置方案// 判断当前位置 (row, col) 是否可以放置皇后bool isValid(int row, int col, vector<string>& chess, int n) {// 检查当前列是否有皇后for (int i 0;…

锐起RDV5高性能云桌面

锐起是上海锐起信息技术有限公司旗下品牌。该公司创立于 2001 年&#xff0c;是桌面虚拟化产品和解决方案提供商&#xff0c;专注于桌面管理系统和私有云存储系统的系列软件产品研发&#xff0c;致力于简化 IT 管理、增强系统安全&#xff0c;提供简单、易用、稳定、安全的产品…

视觉与运动控制6

基于驱动器的控制功能 驱动器的系统性能和运算能力有限需要单独的运动控制器。 V/F恒压频比控制 开环控制方法&#xff0c;应用最广泛、最简单&#xff0c;只需要电机数据即可。适用于控制精度和动态响应要求不高的应用。控制原理&#xff1a;保持点击内磁通量恒定&#xff…

Android10 Settings系列(六)Settings中toolbar 的基本流程,和Activity如何关联,这可能是比较详细的分析

一、前言 写在前面:一个快捷栏,音量浮窗快捷进入设置界面,点击左上角返回键拉起设置首页问题引发的思考和解决方法 事情的起因是测试报了一个问题。在Android9的一个设备在点击音量键时,在弹出的弹框中,点击设置图标快速进入音量设置中,点击左上角返回按钮是,退出当前界…

【深度学习】基于因果表示学习的CITRIS模型原理和实验

1.引言 1.1.本文的主要内容 理解动态系统中的潜在因果因素&#xff0c;对于智能代理在复杂环境中进行有效推理至关重要。本文将深入介绍CITRIS&#xff0c;这是一种基于变分自编码器&#xff08;VAE&#xff09;的框架&#xff0c;它能够从时间序列图像中提取并学习因果表示&…

WebSocket 连接失败的原因及解决方法

WebSocket 目前已经成为了一项极为重要的技术&#xff0c;其允许客户端和服务器之间进行实时、全双工的通信。然而&#xff0c;在实际项目中&#xff0c;开发者时常会遇到 WebSocket 连接失败的情况。这不仅影响了用户体验&#xff0c;还可能导致不可预见的系统错误或数据丢失。…

PMP®项目管理国际认证——管理者的必备证书

北京青蓝智慧科技从2010年起专注于PMP认证的考前培训。 凭借十余年的研究与多位行业专家的合作探讨&#xff0c;我们推出了高效的“PMP通关四步法”。 这一方法经过多年实践检验&#xff0c;显著提升了考生的通过率&#xff0c;为众多学员带来了实质的学习成果。 第一阶段&am…

超低排放标准

据朗观视觉小编了解发现&#xff0c;超低排放标准作为衡量一个行业或企业环保水平的重要指标&#xff0c;越来越受到社会各界的关注。本文将深入探讨超低排放标准的内涵、实施意义以及未来展望。 一、超低排放标准的定义 超低排放标准&#xff0c;是指在特定工业生产过程中&am…

Mac 微信能上网但浏览器打不开网页

文章目录 推荐 DNSMac 设置 DNS 推荐 DNS 名称首选 DNS备用 DNSGoogle8.8.8.88.8.4.4114 DNS114.114.114.114114.114.115.115阿里223.5.5.5百度180.76.76.76腾讯119.29.29.29电信101.226.4.6联通123.125.81.6移动101.226.4.6铁通101.226.4.68福建电信218.85.152.99218.85.157.…

【STM32】USART串口通讯

1.USART简介 STM32芯片具有多个USART外设用于串口通讯&#xff0c;它是 Universal Synchronous Asynchronous Receiver and Transmitter的缩写&#xff0c; 即通用同步异步收发器可以灵活地与外部设备进行全双工数据交换。有别于USART&#xff0c; 它还有具有UART外设(Univers…

C#1.0-11.0所有历史版本主要特性总结

文章目录 前言名词解释主要版本一览表各版本主要特性一句话总结 C# 1.0 (Visual Studio 2002, .Net Framework 1.0)C# 2.0 (Visual Studio 2005, .Net Framework 2.0)C# 3.0 (Visual Studio 2008, .Net Framework 3.0)C# 4.0 (Visual Studio 2010, .Net Framework 4)C# 5.0 (V…

抖音集团基于 Apache Doris 的实时数据仓库实践

作者&#xff1a;字节跳动数据平台 在直播、电商等业务场景中存在着大量实时数据&#xff0c;这些数据对业务发展至关重要。而在处理实时数据时&#xff0c;我们也遇到了诸多挑战&#xff0c;比如实时数据开发门槛高、运维成本高以及资源浪费等。 此外&#xff0c;实时数据处…

掌握Scrum:敏捷开发中的短期迭代与定期会议

目录 前言1. Scrum概述1.1 什么是Scrum1.2 Scrum的三大支柱 2. 短期迭代&#xff08;Sprint&#xff09;2.1 Sprint规划2.1.1 确定Sprint目标2.1.2 创建Sprint待办列表 2.2 Sprint执行2.2.1 每日站会 2.3 Sprint回顾2.3.1 Sprint评审2.3.2 Sprint回顾 3. 定期会议3.1 产品待办列…