Virtual Data Augmentation: 虚拟数据扩增技术

news2025/2/23 12:32:47

听说过数据扩增(Data Augmentation),也听说过虚拟对抗训练(Virtual Adversarial Traning),但是我没想到会有人将其结合,谓之虚拟数据扩增(Virtual Data Augmentation)。这篇文章主要讲解EMNLP2021上的一篇论文Virtual Data Augmentation: A Robust and General Framework for Fine-tuning Pre-trained Models,该论文提出了一种鲁棒且通用的数据扩增方法,论文源码在https://github.com/RUCAIBox/VDA

论文开篇提到目前数据扩增存在的主要问题:产生数据多样性的同时如何保证其仍然在同一个语义空间中?简单地说,增强数据扩增的多样性很容易,核心就一个字:“乱”,例如许多数据扩增方法会随机打乱一个句子中token的位置,或者是随机删除某些token,随机插入某些token。这样虽然增强了样本的多样性,但是语义可能也会产生非常大的变化,甚至不再与原样本的语义相同。保持语义不变,或者说保证扩增后的样本和原样本在同一个语义空间中很容易,核心就是:“不要太乱”,例如通过同义词替换等,这种方法可以做到几乎不改变语义,但是数据多样性却不够,因为本质上还是同一句话

这两个需求实际上是矛盾的,我们所能做的只是尽力达到某种平衡。具体来说,作者所提出的方法包含两个重要部分:Embedding Augmentation以及Regularized Training

Embedding Augmentation

假设现在我们有句子「Time is enough for test」,对于每个位置的token,我们都可以将其替换为[MASK],然后通过MLM预测Vocabulary中所有token在该位置的概率,例如

[MASK] is enough for test

[MASK]位置输出的token及其概率为

Time  p=0.5
Day   p=0.3
Hours p=0.15
...

再比如

Times is enough for [MASK]

[MASK]位置输出的token及其概率为

test       p=0.5
evaluation p=0.3
experiment p=0.1
...

看到这里大家脑海中可能已经有了一个数据扩增的想法,就是利用MLM任务对句子中每个位置的token进行预测,然后根据预测概率随机挑选出一个token进行替换,例如上面的句子可能就会被替换为「Hours is enough for evaluation」。这确实是一种还不错的数据扩增方法,但是论文作者却并不是这么做的

为了描述简单,我们仅讨论对于给定句子 S S S中的一个token w ~ \tilde{w} w~进行扩增的情况(实际上句子 S S S中的所有token都会进行该操作),通过MLM任务我们可以预测出Vocabulary中所有单词在 w ~ \tilde{w} w~位置的概率
{ p ( w ^ 1 ∣ S ) , . . . , p ( w ^ V ∣ S ) } (1) \{p(\hat{w}_1\mid S),...,p(\hat{w}_V\mid S)\}\tag{1} {p(w^1S),...,p(w^VS)}(1)
其中, V V V是Vocabulary中的token数量

为了增强数据扩增的多样性,或者说引入某些噪声以增强抗干扰性,我们从高斯分布中随机采样出一个向量
ϵ ∼ N ( 0 , σ 2 ) (2) \epsilon \sim \mathcal{N}(0, \sigma^2)\tag{2} ϵN(0,σ2)(2)
将该向量与公式(1)的概率分布进行混合,我们可以得到一个新的概率分布
p ′ ( w ^ i ∣ S ) = Softmax ( p ( w ^ i ∣ S ) + ϵ ) (3) p'(\hat{w}_i\mid S) = \text{Softmax}(p(\hat{w}_i\mid S) + \epsilon)\tag{3} p(w^iS)=Softmax(p(w^iS)+ϵ)(3)
然后对于每个即将被替换的token w ~ \tilde{w} w~,我们根据概率 p ′ ( w ^ i ∣ S ) p'(\hat{w}_i\mid S) p(w^iS)加权融合所有token w ^ i \hat{w}_i w^i的Embedding向量
e ^ w ~ = p w ~ ⋅ M E (4) \hat{\mathbf{e}}_{\tilde{w}}=\mathbf{p}_{\tilde{w}}\cdot\mathbf{M}_E\tag{4} e^w~=pw~ME(4)
其中, p w ~ = { p ′ ( w ^ i ∣ S ) } i = 1 V \mathbf{p}_{\tilde{w}}=\{p'(\hat{w}_i\mid S)\}_{i=1}^V pw~={p(w^iS)}i=1V M E ∈ R V × d \mathbf{M}_E\in \mathbb{R}^{V\times d} MERV×d是MLM模型的词向量矩阵

举个简单的例子解释一下,为了方便,同样还是以替换一个token为例,并且整个Vocabulary只有4个token,词向量的维度为2。首先我们有一句话「She is a good student」,将「good」进行MASK,然后通过MLM模型,预测出概率分布为
p ( w ^ i ∣ S ) = [ 0.5 , 0.1 , 0.1 , 0.3 ] p(\hat{w}_i\mid S)=[0.5, 0.1, 0.1, 0.3] p(w^iS)=[0.5,0.1,0.1,0.3]
从左到右分别是good, perfect, excellent, smart的概率,根据高斯分布 N ( 0 , σ 2 ) \mathcal{N}(0, \sigma^2) N(0,σ2)随机产生的向量为
ϵ = [ − 0.1 , 0.1 , 0.1 , − 0.1 ] \epsilon = [-0.1, 0.1, 0.1, -0.1] ϵ=[0.1,0.1,0.1,0.1]

这里我并没有具体指明方差 σ 2 \sigma^2 σ2到底是多少,因为我懒得算

p ( w ^ i ∣ S ) p(\hat{w}_i\mid S) p(w^iS) ϵ \epsilon ϵ混合后进行Softmax得到新的概率分布为
p ′ ( w ^ i ∣ S ) = [ 0.4 , 0.2 , 0.2 , 0.2 ] p'(\hat{w}_i\mid S) = [0.4, 0.2, 0.2, 0.2] p(w^iS)=[0.4,0.2,0.2,0.2]
假设Embedding矩阵为
M E = [ 0.2 , 0.3 0.1 , 0.5 0.4 , 0.2 0.1 , 0.4 ] \mathbf{M}_E = \begin{bmatrix}0.2,0.3\\0.1,0.5\\0.4,0.2\\0.1,0.4\end{bmatrix} ME=0.2,0.30.1,0.50.4,0.20.1,0.4
那么最终「good」这个位置对应的embedding为
e ^ w ~ = p ′ ( w ^ i ∣ S ) ⋅ M E = [ 0.4 0.2 0.2 0.2 ] T ⋅ [ 0.2 , 0.3 0.1 , 0.5 0.4 , 0.2 0.1 , 0.4 ] = [ 0.2 , 0.34 ] \begin{aligned} \hat{\mathbf{e}}_{\tilde{w}} &= p'(\hat{w}_i\mid S) \cdot \mathbf{M}_E\\ &=\begin{bmatrix}0.4\\0.2\\0.2\\0.2\end{bmatrix}^T\cdot \begin{bmatrix}0.2,0.3\\0.1,0.5\\0.4,0.2\\0.1,0.4\end{bmatrix}\\ &= \begin{bmatrix}0.2, 0.34\end{bmatrix} \end{aligned} e^w~=p(w^iS)ME=0.40.20.20.2T0.2,0.30.1,0.50.4,0.20.1,0.4=[0.2,0.34]
到此为止,不知道大家有没有体会到什么叫「Virtual Data Augmentation」,Virtual本质上就是不用一个真实的token去替换,而是使用一个embedding去替换,而如果你用这个embedding去反查 M E \mathbf{M}_E ME矩阵一般是找不到对应的索引的,也就是说我们生成的这个embedding并不对应一个实际存在的token

Regularized Traning

标题起的很有故事,但本质上就是多引入了一个损失函数,具体来说,现在我们的优化目标为
arg ⁡ min ⁡ θ ∑ i = 1 n L c ( f ( x i ) , y i ) + λ ∑ j = 1 k L r e g ( f ( x i ) , f ( x ^ j ) ) (5) \underset{\theta}{\arg \min } \sum_{i=1}^{n} \mathcal{L}_{c}\left(f\left(x_{i}\right), y_{i}\right)+\lambda \sum_{j=1}^{k} \mathcal{L}_{\mathrm{reg}}\left(f\left(x_{i}\right), f\left(\hat{x}_{j}\right)\right)\tag{5} θargmini=1nLc(f(xi),yi)+λj=1kLreg(f(xi),f(x^j))(5)
其中 f f f表示含有参数 θ \theta θ的预训练模型, n n n为样本个数, k k k表示由一条句子扩增出了 k k k条句子。具体来说,如果是分类任务,则
L c ( θ ) = 1 n ∑ i = 1 n CE ( f ( E i ; θ ) , y i ) (6) \mathcal{L}_c(\theta) = \frac{1}{n}\sum_{i=1}^n \text{CE}(f(\mathbf{E}_i;\theta), y_i)\tag{6} Lc(θ)=n1i=1nCE(f(Ei;θ),yi)(6)
其中, CE ( ⋅ , ⋅ ) \text{CE}(\cdot ,\cdot) CE(,)是Cross-Entropy Loss,可以根据具体任务替换的, E i \mathbf{E}_i Ei表示第 i i i条句子通过Word2Vec之后生成的向量,其维度为[seq_len, emd_dim]

为了防止扩增后的样本与原始样本间的语义产生巨大差距,换句话说,我们希望扩增后的样本与原样本间的分布是接近的,因此论文引入了KL散度作为第二项损失
L reg ( θ ) = 1 k ∑ i = 1 k D s K L ( f ( E i ; θ ) , f ( E ^ i ; θ ) ) (7) \mathcal{L}_{\text{reg}}(\theta)=\frac{1}{k}\sum_{i=1}^k D_{sKL}(f(\mathbf{E}_i;\theta), f(\hat{\mathbf{E}}_i;\theta))\tag{7} Lreg(θ)=k1i=1kDsKL(f(Ei;θ),f(E^i;θ))(7)
其中, k k k指的是原样本扩增出了 k k k个样本, D s K L D_{sKL} DsKL是对称的KL散度,具体来说
D s K L ( p , q ) = D K L ( p , q ) + D K L ( q , p ) 2 (8) D_{sKL}(p, q) = \frac{D_{KL}(p, q) + D_{KL}(q, p)}{2}\tag{8} DsKL(p,q)=2DKL(p,q)+DKL(q,p)(8)
实际上这种方法可以看作是多任务,我们希望模型参数训练到一种境界,这种境界是,不论模型对原样本进行下游任务,还是让模型判断原样本与扩增样本的差距,模型都能做的很好。最后给出论文中的一张图结束这部分(图中一个样本扩增了3条样本)

Results

如果单看原始的准确率对比,似乎提升并不是很大,感觉我随便引入一些trick都能达到甚至超过Virtual Data Augmentation的效果。关键在于第二列「Att Acc」,这代表模型受到攻击时的结果,这部分的提升特别大,表明VDA这种方法确实有很强的抗干扰性,或者说鲁棒性很强

个人总结

实际上前面已经把这篇论文讲的很清楚了,这里没有什么好总结的,但我倒是有一点个人拙见想和大家讨论一下,因为他做MLM任务时,将整个Vocabulary都作为候选集,这样无论是对计算速度还是显存占用都不是很友好,我觉得可以将其改为取出概率最大的前Top k个token,这个k可以取的稍微大一点,例如200, 300等,这样可以保证取到后面一些语义上不那么相近的token的同时,避免对整个Vocabulary进行运算,至少不会生成几万几十万那么夸张的概率分布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/10471.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CANoe诊断测试

诊断协议那些事儿 本文为诊断协议那些事儿专栏文章,当我们在开发工程中越来越多的需要使用到总线测试工具,其中包括BUSMASTER、周立功、PCAN、CANOE等,本文将使用德国Vector公司的CANoe介绍诊断测试的基本环境。 文章目录诊断协议那些事儿一…

Python编程从入门到实践 第五章:if语句 练习答案记录

Python编程从入门到实践 第五章:if语句 练习答案记录 练习题导航Python编程从入门到实践 第五章:if语句 练习答案记录5.1 一个简单示例5.2 条件测试5.2.1 检查是否相等5.2.2 检查是否相等时忽略大小写5.2.3 检查是否不相等5.2.4 数值比较5.2.5 检查多个文…

运行yolov5 v6遇到的问题

1. Arial.ttf无法在运行时下载的问题 可以选择用浏览器下载,然后拷贝到docker或者ubuntu下,创建服务器的http访问方式。 具体为: 1.1 下载文件 Arial.ttf 并拷贝到docker或者ubuntu下 1.2 在ubuntu下创建http访问方式: # 安装…

有限元仿真分析误差来源之边界条件,约束和point mass

导读:前不久,我在这里分享了一篇《有限元仿真分析误差来源之材料参数设置,小心为妙》的文章,引发了同行们的关注和讨论。在此感谢仿真秀平台讲师们的批评和指正,一起认真交流技术和进步。今天我将继续带来关于边界条件…

spring data jpa在mysql分页中的实例(一次访问同时获取数据和总数)

一、原生sql语句 mysql中语句如下 select SQL_CALC_FOUND_ROWS sn,max(count) as active_count from sn_state_changed where sn_year zz group by sn limit 0,10; select FOUND_ROWS() as total; 解释: SQL_CALC_FOUND_ROWS 供后面的查询总数sql语句使…

(STM32)从零开始的RT-Thread之旅--SPI驱动ST7735(1)

上一篇: (STM32)从零开始的RT-Thread之旅--GPIO 我使用的开发板是WeAct的H743板子,板子带一个0.96的SPI驱动的LCD,给的有现成的测试用例,看源码应该是ST的工程师写的ST7735的驱动,打算把这个驱动直接拿到RTT工程里面使…

SolidWorks 入门笔记01:草图绘制

全文目录简介1. 草图的创建1.1 在基准面上新建一个二维草图多学一招:退出草图绘制模式,快捷键切换视图1.2 从已有的草图派生新的草图。1.3 在零件的平面上绘制草图多学一招:SolidWorks 中鼠标滚轮放大缩小功能反的解决办法 。2. 基本图形绘制…

【2022秋线上作业-第5次-第11-13周】判断题

1-1 一棵有124个结点的完全二叉树,其叶结点个数是确定的。T 解析: 一棵124个叶节点的完全二叉树,假设n0为叶子节点数,n1为度为1结点数,n2为度为2结点数,则有总结点数为n0n1n2;而n2n0-1123&#…

如何杜绝 spark history server ui 的未授权访问?

如何杜绝 spark history server ui 的未授权访问? 1 问题背景 默认状况下,Spark history Sever ui 是没有任何访问控制机制的,任何用户只要知道 shs 对应的 url,就可以访问链接查看 spark 作业的运行状况。 在证券基金银行等金融行业中&a…

Kotlin 开发Android app(六):Kotlin 中的空判断 问号和感叹号

如果有人对程序的崩溃原因做下统计的话,那么由于对象为空,但是又访问了对象的某个属性而导致的崩溃,也许会是程序崩溃的第一大原因了。 比如我们在使用字符串的时候,变量字符串为空的时候,我们去访问了这个字符串变量的…

2022-11-16 每日打卡:单调栈解决最大矩形问题(一维直方图,二维最大红矩形)

每日打卡:单调栈解决最大矩形问题(一维直方图,二维最大红矩形) 柱状图中最大的矩形 思路 这个题最明显的思路就是:矩形面积底高。 版本1:底的长度可以通过二重循环来完成,高通过循环来寻找最…

44、Spring AMQP 数据转换器

1、操作案例 2、发送一个对象到队列中 3、控制台查看 4、使用消息转换器 5、消费者接收消息, 传递什么类型,就接受什么类型,发送方与接收方所使用的消息转换器必须对应 6、总结分析 默认的消息推送是通过JDK序列化的方式进行的,…

【STM32+cubemx】0029 HAL库开发:HMC5883L磁力计的应用(电子指南针)

今天我们来学习电子磁力计HMC5883L的使用。先介绍磁力计的基础知识,再给一个获取磁力计数据的例子,最后讲解HMC5883L磁力计的校准,以及一些使用中的经验。 1)HMC5883L磁力计的基础知识 磁力计是用来测量磁场强弱(也就…

Android 录音没有声音,设置AudioSource.VOICE_CALL直接MediaRecorder.start异常等系列问题

一、我的需求:来电后,我的三方应用主动开启录音,挂断后结束录音,查验音频 我遇到的问题:录制的音频没有声音。 通过各种尝试,结果如下 :设置不同的录音来源的效果 MediaRecorder API\创建MediaR…

Springboot 结合 MQTT、Redis ,对接硬件以及做消息分发,最佳实践

Springboot 结合 mqtt、redis对接硬件以及做消息分发,最佳实践 一,认识 需要了解EMQX 基本知识原理,不了解的可以查看我之间的博客,以及网上的资料,这里不在过多撰述。 二,开发思路 这里以对接雷达水位计…

【最优化理论】03-无约束优化

无约束优化无约束优化问题无约束优化问题的应用无约束优化问题的最优性条件无约束-凸函数-最优性条件(充要)无约束-一般函数-最优性条件必要条件一阶必要条件:梯度为0二阶必要条件:hessian矩阵半正定充分条件二阶充分条件&#xf…

元宇宙-漫游世界后与Cocos一起看湖南卫视直播

使用参考资源 CocosCreator v3.6.2 cocomat 腾讯开源公共组件框架 Cocos Creator 3D特制 Video MeshRender 播放器(Cocos商店购买) TcPlayer 腾讯开源 Web 播放器 视频流 hls 库 正文 场景漫游引发的思考 元宇宙,虚拟世界。OK,…

【UI编程】将Java awt/swing应用移植到JavaFX纪实

1. 背景 最近想做一个实用的小工具,能屏幕截图,录屏和录制课件,简单的图像处理,和制作gif表情包。翻出了很久以前用Java awt/swing写的一个屏幕截图小程序,能运行,但是屏幕截图到剪贴板后,发现…

深入理解JavaScript-this关键字

先说结论:谁调用它,this 就指向谁 前言 在讲 Function、作用域 时,我们都讲到了 this,因为 JavaScript 中的作用域是词法作用域,在哪里定义,就在哪里形成作用域。而与词法作用域相对应的还有一个作用域叫…

MP157-0-遇见的问题及解决办法

MP157-0-遇见的问题及解决办法1.Win11运行VMware15虚拟机崩溃死机,蓝屏。1.Win11运行VMware15虚拟机崩溃死机,蓝屏。 时间:2022.11.15 解决办法: Hyper-V方案。 打开控制面板-程序-启用或关闭Windows功能,可能你的电…