大模型 | NEFTune之引入随机噪声对大模型训练的收益

news2025/1/23 5:56:29

大模型 | NEFTune之引入随机噪声对大模型训练的收益

paper中提到,在模型foward过程中,对inputs_embedding增加适度的随机噪声,会带来显著的收益。

Paper: https://arxiv.org/pdf/2310.05914.pdf
Github: https://github.com/neelsjain/NEFTune

文章目录

  • 大模型 | NEFTune之引入随机噪声对大模型训练的收益
  • 理论
  • 一. 实践方法
    • 1.1 等待Hugging发布该功能
    • 1.2 直接封装model
    • 1.3 改写compute_loss


理论

核心是输入经过Embedding层后,再加入一个均匀分布的噪声,噪声的采样范围为 [ − α L d , α L d ] [-\frac{\alpha}{\sqrt{Ld}},\frac{\alpha}{\sqrt{Ld}}] [Ld α,Ld α]之间,其中 α \alpha α为噪声超参,L为输入长度,d为Embedding层维度(即hidden维度)
在这里插入图片描述
在AlpacaEval榜单上,利用GPT4作为评分器,在多个数据上微调Llama2-7B模型,NEFTune方法相较于直接微调方法,均有显著提高。
在这里插入图片描述
可以缓解模型在指令微调阶段的过拟合现象,可以更好的利用预训练阶段的知识内容。

一. 实践方法

1.1 等待Hugging发布该功能

进度:等待hugging face正式发布此功能,2023-10-26

[10/17/2023] NEFTune has been intregrated into the Huggingface’s TRL (Transformer Reinforcement Learning) library. See Annoucement.

1.2 直接封装model

进度:直接对模型进行如下封装,原理是对model.embed_tokens.forward()进行改写,经实践,这种方法不管用,会报堆栈溢出的error。

from torch.nn import functional as F

def NEFTune(model, noise_alpha=5)
    def noised_embed(orig_embed, noise_alpha):
        def new_func(x):
            # during training, we add noise to the embedding
            # during generation, we don't add noise to the embedding
            if model.training:
                embed_init = orig_embed(x)
                dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
                mag_norm = noise_alpha/torch.sqrt(dims)
                return embed_init + torch.zeros_like(embed_init).uniform_(-mag_norm, mag_norm)
            else:
                return orig_embed(x)
        return new_func
    ##### NOTE: this is for a LLaMA model ##### 
    ##### For a different model, you need to change the attribute path to the embedding #####
    model.base_model.model.model.embed_tokens.forward = noised_embed(model.base_model.model.model.embed_tokens, noise_alpha)
    return model

1.3 改写compute_loss

进度:loss能够正常计算,但optimzer会报错,可能与精度有关,尚未解决

由于损失函数是自己写的,因此尝试在model(**input)前,追加噪声代码。注意,原先传入model的是input_ids,而当下由于我们将inputs_embeds增加了噪声,因此传入model的将直接替换为inputs_embeds,代码如下

class TargetLMLossNeft(Loss):

    def __init__(self, ignore_index):
        super().__init__()
        self.ignore_index = ignore_index
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def __call__(self, model, inputs, training_args, return_outputs=False):
        input_ids = inputs['input_ids'] # B x L [3, 964]
        attention_mask = inputs['attention_mask'] # B x L 
        target_mask = inputs['target_mask'] # B x L

        ###  ----------------------------- add noise to embeds
        neftune_alpha = 5
        embed_device = model.base_model.model.model.embed_tokens.weight.device
        embeds_init = model.base_model.model.model.embed_tokens.forward(input_ids).to(embed_device) # 先forward一下, 变成B X L X hidden_state
        # embed_device = model.model.embed_tokens.weight.device
        # embeds_init = model.model.embed_tokens.forward(input_ids).to(embed_device)

        input_mask = attention_mask.to(embeds_init) # B x L
        input_lengths = torch.sum(input_mask, 1) # B, 计算每个sample的实际长度
        
        noise_ = torch.zeros_like(embeds_init).uniform_(-1,1) # B X L X hidden_state, 且值域在[-1,1]正态分布
        delta = noise_ * input_mask.unsqueeze(2) # 追加一个维度,由B X L 变成 B X L X hidden_state
        dims = input_lengths * embeds_init.size(-1)
        mag = neftune_alpha / torch.sqrt(dims)
        delta = (delta * mag.view(-1, 1, 1)).detach() # B X L X hidden_state
        inputs_embeds = delta + embeds_init
        ### ----------------------------- add noise to embeds
        

        # 模型前馈预测, 原来传入的是input_ids,而现在需要直接将增加了noise的inputs_embeds传入
        # outputs = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        outputs = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, return_dict=True)
        logits = outputs["logits"] if isinstance(outputs, dict) else outputs[0] # 正常应该是torch.float32
        #logits.requires_grad = True # 奇怪,为什么这里会默认为False, 难道是因为上边的detach()

        # 将labels中不属于target的部分,设为ignore_index,只计算target部分的loss
        labels = torch.where(target_mask == 1, input_ids, self.ignore_index)
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss = self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) # float32
        loss.requires_grad = True
        return (loss, outputs) if return_outputs else loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1136387.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

苹果官宣新品发布会 10月31日发布会与Mac有关

10 月 25 日消息,苹果宣布将于北京时间 10 月 31 日上午 8 点举行主题为“来势迅猛”的线上特别活动,届时或将有新品发布。 这场发布会与以往不同,将在北京时间 10 月 31 日上午 8 点举行。有很多猜测认为苹果届时会发布新款 Mac 电脑&#x…

dropbear-ssh2

编译: ./configure --prefix/home/lxin/workdir/install-dropbear --with-zlib/home/lxin/workdir/zlib/install-zlib-1.2.11/ CCx86_64-linux-gnu-gcc make clean make make install 生产秘钥 in bin dir: ./dropbearkey -t rsa -f ../sbin/dropbear_rsa_host…

赢得国际市场:小企业的跨境电商品牌策略指南

随着全球化的快速发展,跨境电子商务已经成为小企业突破国界、实现全球化梦想的有效途径。然而,成功的跨境电商经营并不仅仅依赖于产品质量和价格竞争力,品牌营销同样至关重要。本文将深入探讨小企业如何在跨境电商领域做好品牌营销。 一、了解…

感受webWorker

B站视频 git完整代码 之前遇到的场景 1、vxe表格计算1000多条数极值/算数平方根的时候。 2、大文件上传时计算hashCode时候 一、不使用webWorker 目录结构 代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8&q…

使用Docker部署Apache Superset并实现公网远程访问

大数据可视化BI分析工具Apache Superset实现公网远程访问 文章目录 大数据可视化BI分析工具Apache Superset实现公网远程访问前言1. 使用Docker部署Apache Superset1.1 第一步安装docker 、docker compose1.2 克隆superset代码到本地并使用docker compose启动 2. 安装cpolar内网…

上帝视角看支付总架构解析

文章目录 1. 支付全局分层2. 交易服务层2.1 服务平台的支付架构2.2 架构的支付部分2.3 架构的清结算部分2.4 完整的架构 3. 支付服务层3.1 支付接收部分3.2 支付处理部分3.3 清结算部分 4. 清算服务层4.1 常见清算组织4.2 银联清算业务4.3 网联清算业务4.3.1 网联支持的业务4.3…

【Linux】 rpm安装包保存到本地并批量安装

目录 一、开启rpm安装包缓存到本地仓库 1. 修改yum.conf文件 2. 清理yum缓存 3. yum命令安装软件包 二、如何将rpm安装包保存到指定目录 方法一&#xff1a;yumdownloader 1. 安装yum-utils  2. yumdownloader命令参数说明 3. yumdownloader安装示例 方法二&#xff…

用Notepad++写java代码

步骤 1.新建&#xff0c;写代码 2.写好之后存为java文件 3.打开命令行 cd 对应位置 javac xxx.java &#xff08;如有中文&#xff09; java xxxdebug 1.错误: 编码utf-8的不可映射字符 这是代码里有中文&#xff0c;编译时加上-encoding utf-8即可 2.错误: 程序包xxx不存在…

ETO制造商目前面临的六大挑战,如何应对?

与离散制造、库存制造不同&#xff0c;按订单设计制造&#xff08;ETO&#xff09;行业面临着一系列独特的挑战。从复杂的产品设计到与客户的密切联系&#xff0c;按订单生产的每件产品都不尽相同。 如果采用按订单生产方式制造产品&#xff0c;管理者总是会想方设法采购最好的…

基于springboot实现书籍学习平台管理系统项目【项目源码+论文说明】

基于springboot实现书籍学习平台管理系统演示 摘要 首先,论文一开始便是清楚的论述了平台的研究内容。其次,剖析平台需求分析,弄明白“做什么”,分析包括业务分析和业务流程的分析以及用例分析,更进一步明确平台的需求。然后在明白了平台的需求基础上需要进一步地设计平台,主要…

礼品家居建材行业出口管理ERP解决方案

根据“一带一路”白皮书显示&#xff0c;2013至2022年&#xff0c;中国与共建国家进出口总额累计19.1万亿美元&#xff0c;年均增长6.4%&#xff1b;与共建国家双向投资累计超过3800亿美元。随着“一带一路”高质量共建&#xff0c;第134届广交会第二期打造的“大家居”主体概念…

Python教程:csv如何保存字典数据

下面是一个示例代码&#xff0c;它将字典数据保存到CSV文件中&#xff1a; #我的Python教程 #微信公众号&#xff1a;wdPython首先创建了一个包含字典数据的列表dict_data。然后&#xff0c;我们使用csv.DictWriter()函数创建一个CSV写入对象&#xff0c;指定了字典中的键作为…

智慧燃气,如何能够防患于未“燃”!

关键词&#xff1a;智慧燃气、智慧燃气建设、智慧燃气平台、智慧燃气系统、燃气数字化 天然气作为一种优质的清洁能源&#xff0c;在改善大气污染中起到非常重要的作用。国家也在将天然气发展成为我国主体能源之一。天然气的快速发展也给智慧燃气带来了光明的前途&#xff0c;…

【python】乱码的处理总结

Python 系列 如果你在Python中遇到了乱码问题&#xff0c;可能是由于字符编码不匹配导致的。以下是一些可能的解决方法&#xff1a; &#xff08;1&#xff09; 确认编码格式&#xff1a;首先要确认你的数据的实际编码格式。常见的编码格式包括UTF-8、GBK、GB2312等。确定正确…

goland无法调试问题解决

goland 无法调试问题解决 golang 版本升级后&#xff0c;goland 无法进行调试了 首先请看自己下载的版本是否有误 1.apple系 M系列芯片的 arm64版本 2.apple系 intel系列芯片的x86_64 3.windows系 intel解决如下&#xff1a; 查看gopath ericsanchezErics-Mac-mini gww-api…

一克商评|未来向外输出自动驾驶技术和解决方案的中国企业会越来越多

封面新闻记者 孟梅 欧阳宏宇 雷强 蔡世奇 付文超 小马智行获沙特新未来城1亿美元投资&#xff0c;并将成立合资公司 小马智行宣布获得沙特阿拉伯王国新未来城&#xff08;NEOM&#xff09;及旗下投资基金NIF&#xff08;NEOM Investment Fund&#xff09;的1亿美元投资。同时…

生成树协议:监控 STP 端口和交换机

什么是生成树协议 生成树协议 &#xff08;STP&#xff09; 用于网络交换机&#xff0c;以防止循环和广播风暴。在局域网 &#xff08;LAN&#xff09; 中&#xff0c;两条或多条冗余路径可以连接到同一网段。当交换机或网桥从所有可用端口传输帧时&#xff0c;这些帧开始在网…

金属纳米颗粒通过水基剥离方案使用嵌段共聚物模板

引言 随着纳米结构表面和界面在广泛的科学和技术应用中变得越来越重要&#xff0c;确定可扩展和廉价的方法来实现这些变成了一个关键的挑战。特别是有序、非密集、表面支撑的金属纳米颗粒的大面积阵列的制造&#xff0c;由于其在不同领域如等离子体增强薄膜太阳能电池中的应用…

基于java+springboot的人事招聘信息网站

运行环境 开发语言&#xff1a;Java 框架&#xff1a;springboot JDK版本&#xff1a;JDK1.8 服务器&#xff1a;tomcat7 数据库&#xff1a;mysql 数据库工具&#xff1a;Navicat11 开发软件&#xff1a;eclipse/myeclipse/idea Maven包&#xff1a;Maven 项目介绍 开发过程…