大模型增量预训练新技巧:解决灾难性遗忘

news2025/1/18 18:49:19

大家好,目前不少开源模型在通用领域具有不错的效果,但由于缺乏领域数据,往往在一些垂直领域中表现不理想,这时就需要增量预训练和微调等方法来提高模型的领域能力。

但在领域数据增量预训练或微调时,很容易出现灾难性遗忘现象,也就是学会了垂直领域知识,但忘记了通用领域知识,之前介绍过增量预训练以及领域大模型训练技巧。

今天给大家带来一篇增量预训练方法-Llama-Pro,对LLMs进行Transformer块扩展后,增量预训练过程中仅对新增块进行训练,有效地进行模型知识注入,并且极大程度地避免灾难性遗忘。

图片

LLaMA Pro: Progressive LLaMA with Block Expansion

LLaMA Pro: Progressive LLaMA with Block Expansion
Paper: https://arxiv.org/abs/2401.02415
Github: https://github.com/TencentARC/LLaMA-Pro

文章目录

    • 技术交流群
    • 用通俗易懂方式讲解系列
    • 块扩展方法
    • 实验细节
    • 讨论分析
    • 写在最后

技术交流群

前沿技术资讯、算法交流、求职内推、算法竞赛、面试交流(校招、社招、实习)等、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企开发者互动交流~

我们建了大模型面试与技术交流群, 想要进交流群、需要源码&资料、提升技术的同学,可以直接加微信号:mlc2060。加的时候备注一下:研究方向 +学校/公司+CSDN,即可。然后就可以拉你进群了。

方式①、微信搜索公众号:机器学习社区,后台回复:加群
方式②、添加微信号:mlc2060,备注:技术交流

资料1
在这里插入图片描述

用通俗易懂方式讲解系列

  • 用通俗易懂的方式讲解:自然语言处理初学者指南(附1000页的PPT讲解)
  • 用通俗易懂的方式讲解:1.6万字全面掌握 BERT
  • 用通俗易懂的方式讲解:NLP 这样学习才是正确路线
  • 用通俗易懂的方式讲解:28张图全解深度学习知识!
  • 用通俗易懂的方式讲解:不用再找了,这就是 NLP 方向最全面试题库
  • 用通俗易懂的方式讲解:实体关系抽取入门教程
  • 用通俗易懂的方式讲解:灵魂 20 问帮你彻底搞定Transformer
  • 用通俗易懂的方式讲解:图解 Transformer 架构
  • 用通俗易懂的方式讲解:大模型算法面经指南(附答案)
  • 用通俗易懂的方式讲解:十分钟部署清华 ChatGLM-6B,实测效果超预期
  • 用通俗易懂的方式讲解:内容讲解+代码案例,轻松掌握大模型应用框架 LangChain
  • 用通俗易懂的方式讲解:如何用大语言模型构建一个知识问答系统
  • 用通俗易懂的方式讲解:最全的大模型 RAG 技术概览
  • 用通俗易懂的方式讲解:利用 LangChain 和 Neo4j 向量索引,构建一个RAG应用程序
  • 用通俗易懂的方式讲解:使用 Neo4j 和 LangChain 集成非结构化知识图增强 QA
  • 用通俗易懂的方式讲解:面了 5 家知名企业的NLP算法岗(大模型方向),被考倒了。。。。。
  • 用通俗易懂的方式讲解:NLP 算法实习岗,对我后续找工作太重要了!。
  • 用通俗易懂的方式讲解:理想汽车大模型算法工程师面试,被问的瑟瑟发抖。。。。
  • 用通俗易懂的方式讲解:基于 Langchain-Chatchat,我搭建了一个本地知识库问答系统
  • 面试了字节大模型算法岗(实习),快被问哭了。。。。

块扩展方法

块扩展,顾名思义,就是在原始模型中每个Transformer块或者某几个Transformer块后增加一个Transformer块,但为了保持扩展后的模型输出保持不变,需要增加的块为恒等块(输入输出相同),如下图所示。

图片

在构建恒等块过程中,主要是将多头注意力层和FFN层中的最后一个线性层(Linear)权重置为0变成Zero-Linear,即可保持经过该块的输入输出一致。

PS:论文附录A中写了大段的推导公式来证明,在此不做过多介绍。

块的增加方式是,对原始模型的 个Transformer块分成 组,每组中包含 个Transformer块,对于每组后添加 个恒等块。代码实现具体如下:

model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
ckpt = model.state_dict()

# original_layers是模型原始层数,layers是模型最后达到层数
split = int(original_layers / (layers - original_layers))

layer_cnt = 0

output = {}
for i in range(original_layers):
    for k in ckpt:
        if ('layers.' + str(i) + '.') in k:
            output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = ckpt[k]
    layer_cnt += 1
    if (i+1) % split == 0:
        for k in ckpt:
            if ('layers.' + str(i) + '.') in k:
                if 'down_proj' in k or 'o_proj' in k:
                    output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = torch.zeros_like(ckpt[k])
                else:
                    output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = ckpt[k]
        layer_cnt += 1
    
assert layer_cnt==layers
for k in ckpt:
    if not 'layers' in k:
        output[k] = ckpt[k]

torch.save(output, output_path)

实验细节

数据由代码和数学组成,其中代码数据采用The-Stack-Dedup数据集中Python语言部分共22B Token,数学数据采用Proof-Pile-2数据集中AlgebraicStack、OpenWebMath和ArXiv部分共55B,详细如下表所示。

图片

数据分布

基础模型为LLaMA2-7B模型,通过块扩展方法将32层模型扩展到40层,其中 、 、 ,每个组从4个Transformer块扩展到5个Transformer块。

对于代码和数学数据进行增量预训练,批量大小为1024,序列最大长度为4096,预热比率为6%,学习率为2e-4,采用余弦学习率调度器,BF16混合精度训练,权重衰减为0.1。使用16个NVIDIA H800 GPU进行了15900个步骤的训练,大约耗费2830个GPU/小时。

在ARC、HellaSwag、MMLU、TruthfulQA、Winogrande、GSM8K、GSM8K-PoT、HumanEval、MBPP等多个评测数据集中进行评测,可以看出,在保持通用任务能力不下降的情况下,数学和代码能力较原始LLaMA2-7B模型有很大提升。

图片

图片

讨论分析

对比块扩展方法与正常训练和Lora方法之间的区别,采用TRACE基准利用总体性能(OP)和逆向转移(BWT)指标进行评估。,如下表所示,块扩展方法整体提升较大。

图片

对比块个数对块扩展方法的影响,进行了不同个数块的实验,并且对比了MoE的方法,训练损失如下,MoE方法的损失下降程度与添加四个块相当。

图片

在代码和法律(16.7B)领域数据下进行增量预训练,在通用任务以及领域任务上比较不同个数块之间的差异,同时比较扩展块全部添加到模型底部或顶部之间的差别,如下所示。可以发现块个数为8时效果最佳,并且不能直接将扩展块全部堆积在头部或尾部,需要分开插入。

图片

写在最后

该方法主要通过增加恒定块扩展模型层数,使模型在增量训练过程中仅训练新增层、冻结原始层,保持模型原有能力,防止模型出现灾难性遗忘现象。

但有两点存疑:

  • 目前来说mistral要好于llama,为啥不用mistral进行实验

  • 不用恒定块,性能会差多少

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1432133.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LLM大模型

LLM 学习链接 : 大语言模型 LLM行业背景和市场需求 大模型的涌现能力 大模型核心前沿 大模型应用范式和职业规划

大数据 - Spark系列《四》- Spark分布式运行原理

Spark系列文章: 大数据 - Spark系列《一》- 从Hadoop到Spark:大数据计算引擎的演进-CSDN博客 大数据 - Spark系列《二》- 关于Spark在Idea中的一些常用配置-CSDN博客 大数据 - Spark系列《三》- 加载各种数据源创建RDD-CSDN博客 目录 🍠…

200行C++代码写一个网络调试助手(TCP服务端TCP客户端)

前言 今天分享一个200行C代码写成的QT网络调试助手。 可以先看看效果 。 因为我不喜欢用QT Designer,因此我用的组件都是使用代码布局的,所以需要设计一下布局。 界面是参考的之前写的串口助手,就是把里面的逻辑改了改,因此外观…

关于网络面试题汇总

什么是TCP/IP五层模型?它们的作用是啥?基于TCP/IP实现的应用(层协议)有哪些? TCP/IP五层模型,从上向下分别是: 应用层:应用程序本身,应用层的作用是负责应用程序之间的…

比特币ETF广告战大爆发!

作者:秦晋 贝莱德主动发起广告攻势。 2月1日,据外媒Cryptoslate报道,贝莱德在提交给美国SEC的一份文件中显示,其提出一项在建筑物侧面投影比特币ETF广告计划。 据介绍,广告内容为:「IBIT」信号是一个以迈阿…

IP风险画像在企业网络安全中应用

随着企业数字化的不断深入,网络安全问题日益突显。IP风险画像作为一种综合性的网络安全工具,为企业提供了更全面的风险评估和防范手段。本文将结合一个实际案例,深入探讨IP风险画像在企业网络安全中的成功应用。 案例背景 一家大型金融机构…

VS2019 添加程序包

dotnet add package AlibabaCloud.SDK.Bailian20230601 来提示添加程序包 选择菜单栏 项目----管理NuGet程序包 输入程序包的名称,然后添加即可, 这只是给当前工程添加,并不是给VS添加,所以你打开新工程,需要使用的话…

详解WebRTC rtc::Thread实现

rtc::Thread介绍 rtc::Thread类不仅仅实现了线程这个执行器(比如posix底层调用pthread相关接口创建线程,管理线程等),还包括消息队列(message_queue)的实现,rtc::Thread启动后就作为一个永不停止的event l…

【图论】基环树

基环树其实并不是树,是指有n个点n条边的图,我们知道n个点n-1条边的连通图是树,再加一条边就会形成一个环,所以基环树中一定有一个环,长下面这样: 由基环树可以引申出基环内向树和基环外向树 基环内向树如…

【开源】WordPress一键崩溃宕机插件(整活娱乐)

插件介绍 可一键实现Wordpress崩溃宕机的整活向插件(请勿用于非法途径,仅供整活娱乐)。鼓励关注网站性能的提升,以提供更好的用户体验,提倡为用户提供良好体验和高效速度的原则。 介绍 长期以来,人们都在…

WordPress从入门到精通【安装部署】

初识WordPress WordPress,简称WP,其简称的由来是取英文单词“word”与“press”的首字母 WP中文官网 1WP主站(英文) 官方标称,已有43%的网站在使用WordPress WordPress亮点 WP使用PHP语言开发,兼容性极…

深度学习/自动驾驶数据集大集合(目标检测/图像分割/语义分割/图像分类/)

深度学习和自动驾驶技术的发展离不开高质量的数据集,这些数据集对于训练和验证各种自动驾驶算法和模型起着至关重要的作用。深度学习/自动驾驶数据集大集合是一项汇集了多种场景、多种数据类型的数据资源,旨在为深度学习和自动驾驶领域的研究者和从业者提…

设计模式-行为型模式(上)

行为型模式用于描述程序在运行时复杂的流程控制,即描述多个类或对象之间怎样相互协作共同完成单个对象都无法单独完成的任务,它涉及算法与对象间职责的分配。 行为型模式分为类行为模式和对象行为模式,前者采用继承机制来在类间分派行为&…

【服务器】RAID(独立磁盘冗余阵列)

RAID(独立磁盘冗余阵列) 一、RAID的介绍二、RAID的分类#2-1 RAID 02-2 RAID 1#2-3 RAID 32-4 RAID 52-5 RAID 62-6 RAID 10(先做镜像,再做条带化)2-7 RAID 01(先做条带,再做镜像)2-8 RAID比较 三、磁盘阵列…

CSDN文章导出工具

源码地址: github:https://github.com/lishuangquan1987/CSDNExportergitee:https://gitee.com/lishuangquan1987/csdnexporter 介绍 最近有CSDN博客导出来的需求,翻看了很多开源工具,都不能用或者不好用,于是决定自己做一个。…

机器学习6-逻辑回归

逻辑回归是机器学习中一种常用于二分类问题的监督学习算法。虽然名字中包含“回归”,但实际上它用于分类任务,特别是对于输出为两个类别的情况。逻辑回归通过使用 logistic 函数将输入映射到一个在0,1范围内的概率值,然后根据这个概率值进行分类。 以下是逻辑回归的基本概念…

线程池,定时器以及阻塞队列(生产者/消费者模型)

💓 博客主页:从零开始的-CodeNinja之路 ⏩ 收录专栏:线程池,定时器以及阻塞队列(生产者/消费者模型) 🎉欢迎大家点赞👍评论📝收藏⭐文章 实现线程池,定时器以及阻塞队列,生产者/消费者模型 线程池线程池…

c++用户管理信息(类指针数组)

用户管理信息--类指针数组 类示意图select类示意图MyIterator示意图VetorCstu示意图ClassStu示意图 项目源代码selectselect.hselect.cpp MyIteratorMyIterator.hMyIterator.cpp VetorCstuVetorCstu.hVetorCstu.cpp ClassStuClassStu.hClassStu.cpp main源码 总结---数组管理指…

中科大计网学习记录笔记(五):协议层次和服务模型

前言: 学习视频:中科大郑烇、杨坚全套《计算机网络(自顶向下方法 第7版,James F.Kurose,Keith W.Ross)》课程 该视频是B站非常著名的计网学习视频,但相信很多朋友和我一样在听完前面的部分发现信…

2024.2.4日总结(小程序开发1)

小程序开发和普通网页开发的区别 运行环境不同 网页运行在浏览器环境中,小程序运行在微信环境中 API不同 由于运行的环境不同,所以小程序中无法调用DCM和BOM的API,但是可以调用微信环境提供的各种API,如:地理定位&…