交叉熵损失的“替代品”:基于最优传输思想设计的分类损失函数EMO

news2024/11/26 21:25:51

3e294d764b84b2c8bec67237225a6fcc.gif

©PaperWeekly 原创 · 作者 | 苏剑林

单位 | 月之暗面

研究方向 | NLP、神经网络

众所周知,分类任务的标准损失是交叉熵(Cross Entropy,等价于最大似然 MLE,即 Maximum Likelihood Estimation),它有着简单高效的特点,但在某些场景下也暴露出一些问题,如偏离评价指标、过度自信等,相应的改进工作也有很多,此前我们也介绍过一些,比如《再谈类别不平衡问题:调节权重与魔改Loss的对比联系》、《如何训练你的准确率?》、《缓解交叉熵过度自信的一个简明方案》[1] 等。

由于 LLM 的训练也可以理解为逐 token 的分类任务,默认损失也是交叉熵,因此这些改进工作在 LLM 流行的今天依然有一定的价值。

在这篇文章中,我们介绍一篇名为 EMO 的工作,它基于最优传输思想提出了新的改进损失函数,声称能大幅提高 LLM 的微调效果。其中细节如何?让我们一探究竟。

9af15e88d389d79edf7815c36958b367.png

论文标题:

EMO: Earth Mover Distance Optimization for Auto-Regressive Language Modeling

论文地址:

https://arxiv.org/abs/2310.04691

b7da6b442750c464271151154ffac23e.png

概率散度

假设 是模型预测的第 个类别的概率,, 则是目标类别,那么交叉熵损失为

4e49e7667a0eab123aee5aa7f43be175.png

如果将标签 用one hot形式的分布 表示出来(即 ),那么它可以重写成

15f3fef3123d369c1bb92fd2735dc823.png

这个形式同时适用于非 one hot 的标签 (即软标签),它等价于优化 的 KL 散度:

920249c3246d01db7a0a6fe7d6e9c993.png

当 给定时,最右端第一项就是一个常数,所以它跟交叉熵目标是等价的。

这个结果表明,我们在做 MLE,或者说以交叉熵为损失时,实则就是在最小化目标分布和预测分布的 KL 散度。由于 KL 散度的一般推广是 f 散度(参考《f-GAN简介:GAN模型的生产车间》),所以很自然想到换用其他 f 散度或许有改良作用。事实上,确实有不少工作是按照这个思路进行的,比如《缓解交叉熵过度自信的一个简明方案》[1] 介绍的方法,其论文的出发点是 “Total Variation 距离”,也是 f 散度的一种。

8a0fc56d2e4c86bd41b708b3d73b9286.png

最优传输

不过,每种 f 散度或多或少有些问题,要说概率分布之间的理想度量,当属基于最优传输思想的“推土机距离(Earth Mover's Distance,EMD)”,不了解的读者可以参考一下笔者之前写的《从Wasserstein距离、对偶理论到WGAN》。

简单来说,推土机距离定义为两个分布之间的最优传输成本:

aa7101fc4368636085e4b5f6814bcdc2.png

这里的 说的是 是任意以 为边缘分布的联合分布, 是实现给定的成本函数,代表“从 搬运到 的成本”, 是下确界,意思就是说将最低的运输成本作为 之间的差异度量。正如基于 f 散度的 Vanilla GAN 换成基于最优传输的 Wasserstein GAN 能够更好的收敛性质,我们期望如果将分类的损失函数换成两个分布的 W 距离,也能收敛到更好的结果。

当 是 one hot 分布时,目标分布就是一个点 ,那么就无所谓最不最优了,传输方案就只有一个,即把 的所有东西都搬到同一个点 ,所以此时就有

924982eed10f8846f981223631237729.png

如果 是一般的软标签分布,那么 的计算是一个线性规划问题,求解起来比较复杂,由于 所定义的分布也属于 ,那么我们有

36335a1ea6251fc8f0f968c657d3a0c3.png

这是一个容易计算的上界,也可以作为优化目标,式(5)则对应 ,其中 是“克罗内克 函数” [2]。

69f5a4dbbf9e258c35cd86be9c87a9b6.png

成本函数

现在回到原论文所关心的场景——LLM 的微调,包括二次预训练和微调到下游任务等。正如本文开头所述,LLM 的训练可以理解为逐 token 的分类任务(类别即所有 token),每个标签是 one hot 的,所以适用于式(5)。

式(5)还差成本函数 还没定下来。如果简单地认为只要 ,那么成本都是 1,即 ,那么

b7e93bc54fa4f4732d6e69ec0bb37b15.png

这其实就是在最大化准确率的光滑近似(参考《函数光滑化杂谈:不可导函数的可导逼近》[4])。但直觉上,所有 都给予同样程度的惩罚似乎过于简单了,理想情况下应该根据相似度来给每个不同的 设计不同的成本,即相似度越大,传输成本越低,那么我们可以将传输成本设计为

18e4823f4fe6ba2a078ebf874eadf196.png

这里的 是事先获取到 Token Embedding,原论文是将预训练模型的 LM Head 作为 Token Embedding 的,并且根据最优传输的定义成本函数是要实现给定的,因此计算相似度的 Token Embedding 要在训练过程中固定不变。

有了成本函数后,我们就可以计算

2a6e475dae3cc6247a05e0e19897bb4d.png

这就是 EMO(Earth Mover Distance Optimization)最终的训练损失。由于 embedding_size 通常远小于 vocab_size,所以先算 能明显降低计算量。

92ad09cb9870dbf7fc2513687671530c.png

实验效果

由于笔者对 LLM 的研究还处于预训练阶段,还未涉及到微调,所以暂时没有自己的实验结果,只能先跟大家一起看看原论文的实验。不得不说,原论文的实验结果还是比较惊艳的。

首先,是小模型上的继续预训练实验,相比交叉熵(MLE)的提升最多的有 10 个点,并且是全面 SOTA:

95809b1c7d33e793cf01cea1c267afb7.png▲ 小模型上的继续预训练对比实验

值得一提的是,这里的评价指标是 MAUVE,越大越好,它提出自《MAUVE: Measuring the Gap Between Neural Text and Human Text using Divergence Frontiers》[3],是跟人工评价最相关的自动评测指标之一。此外,对比方法的 TaiLr 我们曾在《缓解交叉熵过度自信的一个简明方案》[1] 简单介绍过。

可能有读者想 EMO 更好是不是单纯因为评价指标选得好?并不是,让人意外的是,EMO 训练的模型,甚至 PPL 都更好(PPL 跟 MLE 更相关):

38011ea6e56874641ecc4bbf90770967.png

▲ 不同评价指标的对比

然后是将 LLAMA-7B/13B 微调到下游任务做 Few Shot 的效果,同样很出色:

d329dd806387f66e128de04281635f84.png

▲ LLAMA-7B:13B微调到下游任务的效果

最后对比了不同模型规模和数据规模的效果,显示出 EMO 在不同模型和数据规模上都有不错的表现:

a4fba2cff701c27afd3c32dd759f2eda.png

▲ 不同模型规模/数据规模上的效果

e15ff2d1e8689ffa14bba40597b3938c.png

个人思考

总的来说,原论文的“成绩单”还是非常漂亮的,值得一试。唯一的疑虑可能是原论文的实验数据量其实都不算大,不清楚进一步增大数据量后是否会缩小 EMO 和 MLE 的差距。

就笔者看来,EMO 之所以能取得更好的结果,是因为它通过 Embedding 算相似度,来为“近义词”分配了更合理的损失,从而使得模型的学习更加合理。因为虽然形式上 LLM 也是分类任务,但它并不是一个简单的对与错问题,并不是说下一个预测的 token 跟标签 token 不一致,句子就不合理了,因此引入语义上的相似度来设计损失对 LLM 的训练是有帮助的。可以进一步猜测的是,vocab_size 越大、token 颗粒度越大的情况下,EMO 的效果应该越好,因为 vocab_size 大了“近义词”就可能越多。

当然,引入语义相似度也导致了 EMO 不适用于从零训练,因为它需要一个训练好的 LM Head 作为 Token Embedding。当然,一个可能的解决方案是考虑用其他方式,比如经典的 Word2Vec 来事先训练好 Token Embedding,但这可能会有一个风险,即经典方式训练的 Token Embedding 是否会降低 LLM 能力的天花板(毕竟存在不一致性)。

此外,即便 Token Embedding 没问题,从零预训练时单纯用 EMO 可能还存在收敛过慢的问题,这是因为根据笔者在《如何训练你的准确率?》的末尾提出的损失函数视角:

首先寻找评测指标的一个光滑近似,最好能表达成每个样本的期望形式,然后将错误方向的误差逐渐拉到无穷大(保证模型能更关注错误样本),但同时在正确方向保证与原始形式是一阶近似。也就是说,为了保证(从零训练的)收敛速度,错误方向的损失最好能拉到无穷大,而 EMO 显然不满足这一点,因此将 EMO 用于从零训练的时候,大概率是 EMO 与 MLE 的某个加权组合,才能平衡收敛速度和最终效果。

b979cc917573c775ecf8b79e78c39ef5.png

文章小结

本文介绍了交叉熵损失的一个新的“替代品”——基于最优传输思想的 EMO,与以往的小提升不同,EMO 在 LLM 的微调实验中取得了较为明显的提升。

outside_default.png

参考文献

outside_default.png

[1] https://kexue.fm/archives/9526

[2] https://en.wikipedia.org/wiki/Kronecker_delta

[3] https://kexue.fm/archives/6620#正确率

更多阅读

f5461d10b07cdfa7e7d1423a862a8a04.png

edb7bc8f72dbbbad0de95c02b775d8c2.png

3a5428b7afe0811f8453026fa7fc1ddc.png

e26f73405569fb9fbcf8f4696553aa2e.gif

#投 稿 通 道#

 让你的文字被更多人看到 

如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。

总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 

PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。

📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算

📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿

0b9889ae17df5e097eadd13ab3bafdac.png

△长按添加PaperWeekly小编

🔍

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

·

·

f16a940ffb1121be82b8c1845025bb45.jpeg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1160807.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

为什么重写 redisTemplate

为什么重写 redisTemplate 1.安装 redis 上传 redis 的安装包tar -xvf redis-5.0.7.tar.gzyum -y install gcc-cmakemake PREFIX/soft/redis installcd /soft/redis/bin./redis-server redis.conf 2. 集成 redisTemplate maven 依赖 <dependency><groupId>org…

React基础源码解析

前言&#xff1a; 前端魔术师卡颂的react学习视频&#xff08;1 搭建项目架构_哔哩哔哩_bilibili&#xff09;中提到了Rodrigo Pombo的一篇react源码教程&#xff1a;Build your own React 本文档分组旨在翻译和记录这篇文章的学习心得&#xff0c;作为react源码学习入门。 …

MySQL笔记--Ubuntu安装MySQL并基于C++测试API

目录 1--安装MySQL 2--MySQL连接 3--代码案例 1--安装MySQL # 安装MySQL-Server sudo apt install mysql-server# 设置系统启动时自动开启 sudo systemctl start mysql # sudo systemctl enable mysql# 检查MySQL运行状态 sudo systemctl status mysql# 进入MySQL终端 sudo…

VR全景在医院的应用:缓和医患矛盾、提升医院形象

医患关系一直以来都是较为激烈的&#xff0c;包括制度的不完善、医疗资源紧张等问题也时有存在&#xff0c;为了缓解医患矛盾&#xff0c;不仅要提升患者以及家属对于医院的认知&#xff0c;还需要完善医疗制度&#xff0c;提高医疗资源的配置效率&#xff0c;提高服务质量。 因…

vue3的ref源码解析

ref的实现原理 一句话总结: ref本身是个函数&#xff0c;该函数返回一个createRef函数&#xff0c;createRef函数又返回一个“经过类RefImpl实例化”的对象。 详情介绍: ref函数接收我们传入的一个简单类型或复杂类型value&#xff0c;后又将value传递给createRef函数&#xf…

【K8S】二进制安装

常见的K8S安装部署方式 ●Minikube Minikube是一个工具&#xff0c;可以在本地快速运行一个单节点微型K8S&#xff0c;仅用于学习、预览K8S的一些特性使用。 部署地址&#xff1a;https://kubernetes.io/docs/setup/minikube ●Kubeadm☆ Kubeadm也是一个工具&#xff0c;提…

利用Docker容器化构建可移植的分布式应用程序

目录 一、什么是Docker容器化 二、构建可移植的分布式应用程序的优势 三、构建可移植的分布式应用程序的步骤 四、推荐一款软件开发工具 随着云计算和容器化技术的快速发展&#xff0c;将应用程序容器化成为构建可移植的分布式应用程序的一种重要方式。Docker作为目前最为…

批量采集各类自媒体平台内容为word文档带图片软件【支持18家自媒体平台的爬取采集】

批量采集各类自媒体平台内容为word文档带图片软件介绍&#xff1a; 1、支持头条号、大鱼号、企鹅号、一点号、凤凰号、搜狐号、网易号、趣头条、东方号、时间号、惠头条、WiFi万能钥匙、新浪看点、简书、QQ看点、快传号、百家号、微信公众号的文章批量采集为docx文档并带图片。…

c++ Templates:The Complete Guide第二版英文版勘误

看到这里的时候觉得不对劲&#xff0c;一查&#xff0c;果然是写错了&#xff0c;Values应该改成Vs 12.4 Page 204, 12.4.2: s/Values is a nontype template parameter pack.../Vs is a nontype template parameter pack.../Page 204, 12.4.2: s/...provided for the templat…

CSS3设计动画样式

CSS3动画包括过渡动画和关键帧动画&#xff0c;它们主要通过改变CSS属性值来模拟实现。我将详细介绍Transform、Transitions和Animations 3大功能模块&#xff0c;其中Transform实现对网页对象的变形操作&#xff0c;Transitions实现CSS属性过渡变化&#xff0c;Animations实现…

嵌入式每日500(3)231103 (总线结构,存储器映射,启动配置,FLASH读、写、擦除介绍,CRC校验,选项字节,)

这里写目录标题 1.总线结构2.STM32F072VBT6存储器映射3.启动配置&#xff08;BOOT0&#xff0c;BOOT1&#xff09;4.FLASH存储器&#xff08;读、写、擦除&#xff09;5.CRC计算单元6.选项字节 1.总线结构 主模块&#xff08;2个&#xff09;Cortex-M0内核、DMA通道从模块&…

20.4 OpenSSL 套接字AES加密传输

在读者了解了加密算法的具体使用流程后&#xff0c;那么我们就可以使用这些加密算法对网络中的数据包进行加密处理&#xff0c;加密算法此处我们先采用AES算法&#xff0c;在网络通信中&#xff0c;只需要在发送数据之前对特定字符串进行加密处理&#xff0c;而在接收到数据后在…

【面试经典150 | 链表】随机链表的复制

文章目录 Tag题目来源题目解读解题思路方法一&#xff1a;哈希表递归方法二&#xff1a;哈希表方法三&#xff1a;迭代拆分节点 写在最后 Tag 【递归】【迭代】【链表】 题目来源 138. 随机链表的复制 题目解读 对一个带有随机指向的链表进行深拷贝操作。 解题思路 本题一共…

layui form表单 调整 label 宽度

这个可以调整所有label .layui-form-label {width: 120px !important; } .layui-input-block {margin-left: 150px !important; }情况是这样的&#xff0c;表单里有多个输入框&#xff0c;只有个别label 是长的&#xff0c;我就想调整一下个别长的&#xff0c;其它不变 <di…

小程序day02

目标 WXML模板语法 数据绑定 事件绑定 那麽問題來了&#xff0c;一次點擊會觸發兩個組件事件的話&#xff0c;該怎么阻止事件冒泡呢&#xff1f; 文本框和data的双向绑定 注意点: 只在标签里面用value“{{info}}”&#xff0c;只会是info到文本框的单向绑定&#xff0c;必须在…

1、循环依赖详解(一)

什么是循环依赖&#xff1f; 什么情况下循环依赖可以被处理&#xff1f; Spring是如何解决的循环依赖&#xff1f; 只有在setter方式注入的情况下&#xff0c;循环依赖才能解决&#xff08;错&#xff09; 三级缓存的目的是为了提高效率&#xff08;错&#xff09; 什么是循环…

在基于亚马逊云科技的湖仓一体架构上构建数据血缘的探索和实践

背景介绍 随着大数据技术的进步&#xff0c;企业和组织越来越依赖数据驱动的决策。数据的质量、来源及其流动性因此显得非常关键。数据血缘分析为我们提供了一种追踪数据从起点到终点的方法&#xff0c;有助于理解数据如何被转换和消费&#xff0c;同时对数据治理和合规性起到关…

Ajax学习笔记第8天

放弃该放弃的是无奈&#xff0c;放弃不该放弃的是无能&#xff0c;不放弃该放弃的是无知&#xff0c;不放弃不该放弃的是执着&#xff01; 【1. 聊天室小案例】 文件目录 初始mysql数据库 index.html window.location.assign(url); 触发窗口加载并显示指定的 url的内容 当前…

TSINGSEE青犀特高压输电线可视化智能远程监测监控方案

一、背景需求分析 特高压输电线路周边地形复杂&#xff0c;纵横延伸几十甚至几百千米&#xff0c;并且受所处地理环境和气候影响很大。传统输电线路检查主要依靠维护人员周期性巡视&#xff0c;缺乏一定的时效性&#xff0c;在巡视周期的真空期也不能及时掌握线路走廊外力变化…

AQS面试题总结

一&#xff1a;线程等待唤醒的实现方法 方式一&#xff1a;使用Object中的wait()方法让线程等待&#xff0c;使用Object中的notify()方法唤醒线程 必须都在synchronized同步代码块内使用&#xff0c;调用wait&#xff0c;notify是锁定的对象&#xff1b; notify必须在wait后执…