CV算法工程师的LLM日志(1)微调技术——即插即用的neft-Tune【原理代码】

news2024/9/25 9:39:56

CV算法的LLM领域日志

目前维护的CV方向开源项目暂时暂停,原因是现在在做LLM方向的研发工作,所以需要时间消化前沿技术和总结经验,最近看到了一个非常简单的LLM训练Trick 分享一下,后续会逐渐把自己使用的一些LLM范式技术原理和代码分享。


文章目录

  • CV算法的LLM领域日志
  • Neft-tune
  • 一、动机
  • 二、代码
    • 1.引入真实上下文的代码
  • 总结


Neft-tune

Noisy embeddings improve instruction finetune,官方文章点此,简单来说,这是一种嵌入噪声的微调方法,已经被HuggingFace收录进了TRL库,只要import再加一行代码就能调用。

from datasets import load_dataset
from trl import SFTTrainer
dataset = load_dataset("your data path ", split=“train”)
trainer = SFTTrainer(
“facebook/opt-350m”,
train_dataset=dataset,
dataset_text_field=“text”,
max_seq_length=512,
)
trainer.train()

NEFT不仅操作简便,而且没有显著的成本增加,作者称看起来是个“免费的午餐”。


一、动机

在做LLM的微调时候,往往都容易在finetune数据集上过拟合,作者目的就是希望在训练阶段向嵌入层中加入噪声的方式作为一种正则化方式来改善这个情况,从而提高性能,仅针对训练阶段,所以非常适合微调工程,下面是算法逻辑:

1.分词器Embedding。
2.随机生成一个噪声, 然后,系统会随机生成一个噪声向量,并根据给定的超参数进行缩放。
3.缩放后的向量会和embedding相加,在训练的embedding层的forward过程中存在
4. 训练每轮迭代1-3步骤。

在这里插入图片描述
在此Trick使用后,作者团队在LLAMA2-7B和Mistral-7B上分别实现了翻倍和25%的提升,按原话说至少提升10%的"免费午餐”,且进一步验证在文本生成任务的质和量上也Work.

二、代码

1.引入真实上下文的代码

NEFT的代码很简单,相当于是对模型的embedding层进行了一次改动和额外引入噪声系数的超参数(默认为5),实际在使用中发现,NEFT的改动是在模型加载后一个结构的后调整,这样在代码逻辑中非常容易实现,不会影响其他上下文的程序逻辑,下面给出一个模拟实际上下文的代码和注释参考:

代码如下(示例):

import transformers
from transformers import Trainer#, GPTQConfig
model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=training_args.cache_dir,
        device_map=device_map,
        trust_remote_code=True,
      #  quantization_config=GPTQConfig(
          #  bits=4, disable_exllama=True
       # )
        if training_args.use_lora and lora_args.q_lora
        else None,
    )
tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
        use_fast=False,
        trust_remote_code=True,
    )
##首先进行模型和分词器的官方库加载即可
##加入NEFT的模型结构设置即可,下面是核心代码,给定一个噪声强度系数默认值为5

 neft_alpha=5
 print("strat neft!")
 if neft_alpha > 1e-6:
      #  如果模型处于训练模式,则对嵌入矩阵的每个元素加上一个服从均匀分布的随机噪声,噪声范围由finetuning_args.neft_alpha和嵌入矩阵维度决定,在训练模式下,首先计算出噪声的幅度因子(mag_norm),然后torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)生成与原始嵌入向量相同形状的张量,该张量的元素取值范围在[-mag_norm, mag_norm]之间,再将该张量加到原始嵌入向量上。最后,函数返回添加了噪声的嵌入向量。核心思路是这个写法改动了模型原始input embedding层的FORWARD操作。
      
        input_embed = model.get_input_embeddings()
        if isinstance(input_embed, torch.nn.Embedding):
            def noisy_forward(self: torch.nn.Embedding, x: torch.Tensor) -> torch.Tensor:
                embeddings = input_embed.__class__.forward(self, x)
                if self.training:
                    dims = self.num_embeddings * self.embedding_dim
                    mag_norm = neft_alpha / (dims ** 0.5)
   	                embeddings+=torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)
                return embeddings

			input_embed.forward = noisy_forward(input_embed)
            print("Using noisy embedding with alpha={:.2f}".format(neft_alpha))
        else:
            print("Input embeddings are not normal nn.Embedding, cannot transform into noisy embedding.")


这样模型在训练阶段输入的embedding层在前向传播中会执行我们设定好的Noise embedding funetune方法,推理中则按照原始结构进行,后面可以接入我们常用的lora等微调模型的方法。

总结

后面也会对LORA\QLORA\LONGLORA等微调技术以及更多原理方法逐一解析分享。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1131101.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

部分背包问题细节(贪心)

有一种情况是,背包可以把金币全部拿走: 如果num小于0则返回值

重症医学科常用评估量表汇总,建议收藏!

根据重症医学科医生的量表使用情况,笔者整理了10个重症医学科常用量表,可在线评测直接出结果,可转发使用,可生成二维码使用,可创建项目进行数据管理,有需要的小伙伴赶紧收藏! 简明急性生理功能评…

Rocksdb LSM Tree Compaction策略

RocksDB读写简介 直接画图说明。这张图取自Flink PMC大佬Stefan Richter在Flink Forward 2018演讲的PPT,笔者重画了一下。 RocksDB的写缓存(即LSM树的最低一级)名为memtable,对应HBase的MemStore;读缓存名为block cac…

基于QT的图书管理系统

获取代码: 知识付费时代,低价有偿获取代码,请理解! (1) 下载链接: 后发(2) 添加博主微信获取(有偿),备注来源: mryang511688(3) 快速扫码咨询: 项目描述 技术:C、QT等 摘要&#…

(1)(1.8) Hondex声纳

文章目录 前言 1 推荐的硬件 2 连接和配置 3 参数说明 前言 Hondex HE-8S 是一款回声测深仪(又称水下声纳),测深范围 100m,内置 GPS 和 NMEA 输出,可由 ArduPilot 使用。其他 Hondex 声纳也可以使用,但…

redis学习(三)——java整合redis

Jedis Jedis可以用于java连接redis数据库 新建一个maven项目&#xff0c;导入Jedis依赖 <dependency><groupId>org.junit.jupiter</groupId><artifactId>junit-jupiter</artifactId><version>RELEASE</version><scope>test…

mcgsTpc屏与施耐德TM218PLC通讯说明

一、 硬件连接 1、PLC通讯接口说明&#xff1a; 2、通讯电缆图&#xff1a; 二、PLC设置 1. 配置端口&#xff1a; 双击串行线路—弹出右侧设置窗口---设置串口通讯参数 2. 添加MODBUS协议。 ① 右击串口线路&#xff0c;选择添加设备&#xff1a; ② 选择现场总线&#xf…

Studio One6.5中文版本更新下载

供所有 Studio One 6 和 Studio 用户下载&#xff0c;以下是部分改进&#xff01; Studio One 6.5 更新功能&#xff1a; 空间音频制作工作流程与杜比全景声集成 支持杜比全景声双耳耳机监听 字幕和功能铺增强功能 整合杜比全景声的空间音频制作工作流程 Studio One 6.5…

损失函数总结(三):BCELoss、CrossEntropyLoss

损失函数总结&#xff08;三&#xff09;&#xff1a;BCELoss、CrossEntropyLoss 1 引言2 损失函数2.1 BCELoss2.2 CrossEntropyLoss 3 总结 1 引言 在前面的文章中已经介绍了介绍了一系列损失函数 (L1Loss、MSELoss)。在这篇文章中&#xff0c;会接着上文提到的众多损失函数继…

关于路由转发

路由表的作用 路由表的作用&#xff1a;目标网络匹配路由表&#xff0c;从相应网络转发&#xff1b;不匹配路由表&#xff0c;丢弃或转发至默认路由器。 路由转发的原理 根据IP地址找到目标网络&#xff0c;由应路由器解封装查看目标网络是否可达&#xff0c;重新封装进行转…

什么是数据中心的测试负载?

数据中心的测试负载是指在数据中心环境中进行的负载测试&#xff0c;以评估数据中心的性能、可靠性和可扩展性。负载测试是通过模拟实际使用情况&#xff0c;向数据中心的系统和组件施加各种类型的负载&#xff0c;以确定其在不同负载条件下的表现和响应能力。 通过模拟高负载情…

GCE的安装和使用

GCE的安装和使用 GCE的安装使用1. GCE的安装2. GCE的使用补充&#xff1a;一个简单的R脚本——kmerpdf.R&#xff0c;用于绘制kmer的种类和数量分布图 GCE的安装使用 一个基因组评估软件。其他同类型软件Genomescope 1. GCE的安装 Github官网&#xff1a;https://github.com…

css-渐变色矩形

效果图&#xff1a; 代码&#xff1a; html: <!DOCTYPE html> <html><head><meta charset"utf-8"><meta name"viewport" content"initial-scale1.0, user-scalableno" /><title></title><link …

一文讲解电源技术中专为准谐振转换器供电 高性能电流模式控制器NCP1380BDR2G

NCP1380BDR2G是一款高性能器件&#xff0c;旨在为准谐振转换器供电。该控制器基于专属的谷锁闭系统&#xff0c;可以在功率负载变轻时进行切换并降低开关频率。这样将产生稳定的运行&#xff0c;即使在漏极-源极谷中总是触发的开关事件下也是如此。此系统可在低至第 4 个谷的条…

Zip密码忘记了,如何破解密码?

Zip压缩包设置了密码&#xff0c;解压的时候就需要输入正确对密码才能顺利解压出文件&#xff0c;正常当我们解压文件或者删除密码的时候&#xff0c;虽然方法多&#xff0c;但是都需要输入正确的密码才能完成。忘记密码就无法进行操作。 那么&#xff0c;忘记了zip压缩包的密…

YOLOv5 添加 OTA,并使用 coco、CrowdHuman数据集进行训练。

YOLO-OTA 第一步&#xff1a;拉取 YOLOv5 的代码第二步&#xff1a;添加 ComputeLossOTA 函数第二步&#xff1a;修改 train 和 val 中损失函数为 ComputeLossOTA 函数1、在 train.py 中 首先添加 ComputeLossOTA 库。2、在 train.py 修改初始化的损失函数3、在 train.py 修改一…

.net 支付宝 应用网页验签

验证签名接口 /// <summary>/// 验证网关/// </summary>/// <returns></returns>[Route("gatewayVerify"), HttpPost, AllowAnonymous, NonUnify]public async Task<dynamic> gatewayVerify(){var Request App.HttpContext.Request;…

文件恢复怎么做?学会这3招!快速恢复文件!

“大家平常文件丢失时有什么简单的方法来恢复文件吗&#xff1f;我总是很粗心会将文件丢失&#xff0c;但是又不知道应该如何恢复&#xff0c;请大家给我出出主意吧&#xff01;” 如果我们保存在电脑上的文件丢失了&#xff0c;可能会影响到我们正常工作的进度。但在实际使用电…

有了for循环 为什么还要forEach?

js中那么多循环&#xff0c;for for…in for…of forEach&#xff0c;有些循环感觉上是大同小异今天我们讨论下for循环和forEach的差异。for循环和forEach都是用于遍历数组或类数组对象的工具。它们之间有一些区别&#xff0c;使用哪个取决于具体的需求。 本质区别: for 循环是…

cleanmymacX4.14免费版mac清除浏览器缓存软件

当我们使用浏览器访问网站时&#xff0c;浏览器会自动缓存一些数据&#xff0c;比如网页缓存、DNS缓存、插件缓存、SSL证书缓存和Cookie缓存等。虽然有些缓存可以提高浏览器的使用体验&#xff0c;但是缓存过多也会导致一些问题&#xff0c;比如网页更新后浏览器仍然显示旧的内…