论文笔记--Prompt Consistency for Zero-Shot Task Generalization

news2024/11/16 3:51:26

论文笔记--Prompt Consistency for Zero-Shot Task Generalization

  • 1. 文章简介
  • 2. 文章概括
  • 3 文章重点技术
    • 3.1 Prompt-based zero-shot task generalization
    • 3.2 Prompt Consistency Training
    • 3.3 如何防止遗忘和退化?
  • 4. 文章亮点
  • 5. 原文传送门

1. 文章简介

  • 标题:Prompt Consistency for Zero-Shot Task Generalization
  • 作者:Chunting Zhou, Junxian He, Xuezhe Ma, Taylor Berg-Kirkpatrick, Graham Neubig
  • 日期:2022
  • 期刊:Arxiv preprint

2. 文章概括

  文章基于prompt的一致性学习给出了一种zero-shot task generalization(零样本泛化学习)的无监督方法。数值实验表明,文章提出的指令一致性学习方法只需在几个prompt、几十个样本上进行训练,就可以在NLI等NLP任务上追平SOTA水平。
  文章整体架构如下
整体架构

3 文章重点技术

3.1 Prompt-based zero-shot task generalization

  首先简单介绍下zero-shot task generalization(零样本泛化学习):给定输入 x ∈ X x\in \mathcal{X} xX,零样本泛化学习旨在学习一个预训练模型PLM预测出 y ∈ Y y\in \mathcal{Y} yY,其中PLM未在数据集 X \mathcal{X} X上训练过。零样本泛化学习要求模型可以泛化出一个新的表达式 f : X → Y f: \mathcal{X} \to \mathcal{Y} f:XY,而非仅仅在数据集上具有泛化能力。
  给定prompt r r r, r r r包含一个输入模板 r x r_x rx、输出模板 r y r_y ry以及待放入模板的元数据 x , y x, y x,y,我们可以得到prompt-based输入: r x ( x ) , r y ( y ) r_x(x), r_y(y) rx(x),ry(y)。基于prompt的学习方法一般用 p θ ( r y ( y ) ∣ r x ( x ) ) p_{\theta} (r_y(y)|r_x(x)) pθ(ry(y)rx(x))来计算输出的概率 q ( y ∣ x , r ) ) q(y|x, r)) q(yx,r)),其中 θ \theta θ表示模型的参数。本文重点关注NLP的分类任务,则可以通过如下公式计算输出的概率: q ( y ∣ x , r ) = p θ ( r y ( y ) ∣ r x ( x ) ) ∑ y ′ ∈ Y p θ ( r y ( y ′ ) ∣ r x ( x ) ) (1) q(y|x, r) = \frac{p_{\theta} (r_y(y)|r_x(x))}{\sum_{y'\in\mathcal{Y}} p_{\theta} (r_y(y')|r_x(x))}\tag{1} q(yx,r)=yYpθ(ry(y)rx(x))pθ(ry(y)rx(x))(1)

3.2 Prompt Consistency Training

   文章的方法需要无标注的数据集 { x 1 , … , x N } \{x_1, \dots, x_N\} {x1,,xN} K K K个不同的prompt { ( r x 1 , r y 1 ) , … , ( r x K , r y K ) } \{(r_x^1, r_y^1), \dots, (r_x^K, r_y^K)\} {(rx1,ry1),,(rxK,ryK)}。其中无标注的数据集可以来自任意NLP(分类)任务的训练数据集或测试数据集,也可以来自我们要测试的任务的数据集。prompt可直接采用Public Pool of Prompts(p3)数据集里的prompt。
   传统的一致性训练会扰乱样本,使得扰乱后的样本和之前的样本得到的输出尽可能一致。本文希望学习prompt级别的一致性,即不同prompt在单个样本上的学习结构尽可能一致。这样做可以1) 概念非常简单 2)缓解PLM“输入不同prompt结果不一致”的问题。
  损失函数定义如下 L = − E x ∈ p d ( x ) E r i , e r j ∈ p ( r ) E y ^ ∈ q ^ ( y ∣ x , r i ) log ⁡ p θ ( r y j ( y ^ ) ∣ r x j ( x ) ) \mathcal{L} = -\mathbb{E}_{x\in p_d(x)} \mathbb{E}_{r^i, er^j\in p(r)} \mathbb{E}_{\hat{y} \in \hat{q}(y|x,r^i)} \log p_{\theta} (r_y^j(\hat{y})|r_x^j(x)) L=Expd(x)Eri,erjp(r)Ey^q^(yx,ri)logpθ(ryj(y^)rxj(x))
, p d p_d pd是数据集的分布, p ( r ) p(r) p(r)表示 K K K个prompt的随机prompt对的均匀分布, q ^ \hat{q} q^定义为式(1)的条件分布。这里简单解释下,如图所示,给定prompt r i , r j r^i, r^j ri,rj,我们首先预测 y ^ ∈ q ^ ( y ∣ x , r i ) \hat{y}\in \hat{q}(y|x, r^i) y^q^(yx,ri),即当promt为 r i r^i ri时得到输出 y ^ \hat{y} y^。当prompt为 r j r^j rj时,我们希望最大化输出结果为 y ^ \hat{y} y^(即和 r i r^i ri输出相同)的概率 p θ ( r y j ( y ^ ) ∣ r x j ( x ) ) p_{\theta} (r_y^j(\hat{y})|r_x^j(x)) pθ(ryj(y^)rxj(x)),取负对数和期望之后,即得到上述损失函数。我们称上述训练方法为swarm distillation。

3.3 如何防止遗忘和退化?

  如果直接采用上述方法进行训练,则我们很容易collapse,得到一个平凡解:所有prompt、所有样本均输出同一个结果可以实现损失函数最小。另一方面,训练后的模型可以能忘记之前的知识,即castrophic forgetting。为了避免collapse和catastrophic forgetting,文章提出下述两种方法:

  1. LoRA:文章是在T0模型上层进行训练的,为了不发生灾难性遗忘,文章采用了LoRA方法,即通过两个低阶矩阵的乘积进行迭代学习,具体如下图所示。在实际训练时我们将LoRA应用到Transformer每一个前馈层。
    loRA
  2. Fleiss’ Kappa:由于我们没有标注数据作为validation set,从而很难选择一个最佳的checkpoint作为最终模型。为此文章采用了Fleiss’ Kappa指标来度量模型的效果。首先,我们定义一致性概率。对给定的样本 x i x_i xi,记所有 K K K个prompt中预测输出为第 j j j个label的prompt数量为 n i j n_{ij} nij,则对该样本,任意两个prompt给出相同的预测结果的概率为 p i = ∑ j ( n i j 2 ) / ( K 2 ) = ∑ j n i j ( n i j − 1 ) / K ( K − 1 ) p_i = \sum_j \binom {n_{ij}}2 /\binom K2 = \sum_{j} n_{ij}(n_{ij} - 1) / K(K-1) pi=j(2nij)/(2K)=jnij(nij1)/K(K1),所有样本的绝对一致性为 P ‾ = ∑ i p i \overline{P} = \sum_i p_i P=ipi。另一方面,第 j j j个label的占比为 q j = ∑ i n i j / N K q_j = \sum_i n_{ij}/NK qj=inij/NK,则 P ‾ e = ∑ j q j 2 \overline{P}_e = \sum_j q_j^2 Pe=jqj2表示任意两个prompts按照标签的分布随机预测结果一致的概率。当所有 q j q_j qj均相等时, P ‾ e \overline{P}_e Pe最小,即预测的标签随机分布。最终得到Fleiss’ kappa度量为 κ = P ‾ − P ‾ e 1 − P ‾ e ∈ ( − 1 , 1 ) \kappa = \frac {\overline{P} - \overline{P}_e}{1 - \overline{P}_e} \in (-1, 1) κ=1PePPe(1,1),其中 P ‾ e \overline{P}_e Pe越大, κ \kappa κ越小,即预测的结果如果被一个类别主导,则 κ \kappa κ会被惩罚。

4. 文章亮点

  文章提出了一种基于prompt一致性的zero-shot task generation学习方法swarm distillation,且采用了LoRA和Fleiss’ Kappa方法避免学习灾难性遗忘或学习结果collapse。文章在多个NLP下游任务上进行了验证,发现swarm distillation在多个任务上表现超过SOTA。此外,数值实验表明,swarm distillation只需要4个prompt,10+个样本就可以对源模型(T0)进行提升。
  但实验也表明,swarm distillation方法在增加到一定样本量之后性能就达到了饱和,当我们有很多标记样本可用的时候,性能可能不及监督微调。未来可以将swarm distillation与few-shot少样本学习或预训练相结合来实现在标记样本上的性能提升。

5. 原文传送门

Prompt Consistency for Zero-Shot Task Generalization

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/677135.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【numpy模块上}——数据分析01

目录索引 介绍:用处与特点:构成:导包:创建数组: numpy常用方法:常用属性查看:*获取秩的大小:**获取数组形状:**获取元素个数:**获取元素类型:**获…

行为型设计模式10-解释器模式

🧑‍💻作者:猫十二懿 ❤️‍🔥账号:CSDN 、掘金 、个人博客 、Github 🎉公众号:猫十二懿 解释器模式 1、解释器模式介绍 解释器模式(Interpreter Pattern)是一种行为设…

Kafka系列之:对源连接器的的Exactly-Once支持

Kafka系列之:对源连接器的的Exactly-Once支持 一、背景二、目标三、公共接口四、连接器 API 扩展五、REST API验证六、新指标七、计划变更八、任务计数记录九、重新平衡的准备十、源任务启动十一、领导者访问配置主题十二、用于隔离事务生产者的管理 API十三、解决了…

论文阅读 - SegFormer

文章目录 1 概述2 模型说明2.1 总体结构2.2 Hierarchical Transformer Encoder2.3 Lightweight All-MLP Decoder 3 SegFormer和SETR的比较参考资料 1 概述 图像分割任务和图像分类任务是非常相关的,前者是像素级别的分类,后者是图像级别的分类。基于分类…

不到3000块,搭建IT人的实验平台!性能可媲美服务器!

作为IT从业者,特别是运维这个岗位,没有自己的实验平台真的特别难受,那么如何搭建自己的实验平台呢?这是我最近思考并付诸实践的一个事情,最终找到了自己觉得比较可以的方案。 01 我的需求是什么? 大内存容量…

TypeScript——类(class)

ES6 中类的用法 下面我们先回顾一下 ES6 中类的用法,更详细的介绍可以参考 ECMAScript 6 入门 - Class 属性和方法 使用 class 定义类,使用 constructor 定义构造函数。 通过new生成新实例的时候,会自动调用构造函数。 class Person{con…

leetcode877. 石子游戏(动态规划-java)

石子游戏 leetcode877. 石子游戏题目描述暴力递归代码演示 动态规划 动态规划专题: leetcode877. 石子游戏 来源:力扣(LeetCode) 链接:https://leetcode.cn/problems/stone-game 题目描述 Alice 和 Bob 用几堆石子在做游戏。一共有…

HTTP与Fiddler使用

HTTP与Fiddler使用 HTTP与Fiddler使用FidderHTTP的报文结构:其他请求头User-agentReferer和OrigincookieHTTP状态码 HTTP与Fiddler使用 HTTP协议是使用十分广泛的应用层协议,也是一个可以由程序员进行设置的一个协议。该协议的结构规定了浏览器与万维网…

【C++】通讯录的基本实现,附有源码分享

目录 1、运行环境 2、系统实现功能 2.1菜单功能 2.2退出通讯录功能 2.3添加联系人功能 2.4显示联系人功能 2.5删除联系人功能 2.6查找联系人功能 2.7修改联系人功能 2.8清空联系人功能 2.9动态扩容功能 2.10选择优化功能 2.11文件操作 3、源码分享 1、运行环境 …

【备战秋招】每日一题:2023.04.26-华为OD机式-第三题-MC方块

在线评测链接:P1231 题目内容 MC最新版本更新了一种特殊的方块,幽匿催发体。这种方块能够吸收生物死亡掉落的经验并感染周围方块,使其变成幽匿块。Steve想要以此为基础尝试搭建一个经验仓库,他来到了创造超平坦模式,在只有草方块…

【微信小程序开发】第 7 课 - 小程序的常用组件

欢迎来到博主 Apeiron 的博客,祝您旅程愉快 ! 时止则止,时行则行。动静不失其时,其道光明。 目录 1、缘起 2、小程序中组件的分类 3、常用的视图容器类组件 3.1、view 组件 3.2、scroll - view 组件 3.3、swiper 和 swiper…

blfs:为lfs虚拟机增加桌面02

参考书籍: BLFS11.3 LFS11.3(这里面有软件安装的详细说明) 树莓派Linux操作系统移植(这里面有桌面系统的脉络梳理) 参考视频 https://www.youtube.com/watch?vcavxyXBgJ6Q&listPLyc5xVO2uDsBK_3VZOek8ICsxewOO4DU…

Vue3 网络请求——axios 高级用法之 axios 拦截器实战与并发请求

文章目录 📋前言🎯关于拦截器🎯项目创建🎯代码分析🎯补充:并发请求🧩axios.all() 和 Promise.all() 的区别 📝最后 📋前言 Axios 是一个流行的基于 Promise 的 HTTP 客户…

机器学习中的多分类问题

文章标题:机器学习中的多分类问题 机器学习中的分类问题可以大致分为二分类和多分类两种。在二分类问题中,模型需要将输入数据分为两类;而在多分类问题中,模型需要将输入数据分为多个类别。本文将介绍机器学习中的多分类问题及其…

C语言指针类型,8个例子给你讲明白

0.问题 知乎上回答了一个粉丝问题, 结果这兄弟又连续问了几个问题: 好吧,帮人帮到底,送佛送到西!给你讲彻底点吧! 1. int va; 这是一个整型变量,32位CPU的话,占有32个bite 2. in…

Redis入门(1)

1.NOSQL概述 1.1.什么是NOSQL NoSQL,泛指非关系型的数据库。随着互联网web2.0网站的兴起,传统的关系数据库在处理web2.0网站,特别是超大规模和高并发的SNS类型的web2.0纯动态网站已经显得力不从心,出现了很多难以克服的问题&…

设计模式之享元模式笔记

设计模式之享元模式笔记 说明Flyweight(享元)目录享元模式示例类图抽象图形类I图形类L图形类O图形类工厂类测试类 说明 记录下学习设计模式-享元模式的写法。JDK使用版本为1.8版本。 Flyweight(享元) 意图:运用共享技术有效地支持大量细粒度的对象。 结构: 其中&#xff1…

MCU(Cortex - M3/M4)启动加载过程和内存分配原理 笔记

最近发现对基础不太熟悉,写篇笔记记录一下MCU启动到用户C语言运行,之前做了那些工作,同时flash和Ram又分别保存了那个数据,每一段又是什么意义,方便后续自己忘记了,查阅。 一、 MCU启动 在MCU上电/复位之后…

WireShark常用协议抓包与原理分析

1.ARP协议(地址解析协议) nmap 发现网关nmap -sn 192.168.133.2wireshark 抓请求包和响应包 arp请求包内容 arp响应包内容 总结:请求包包含包类型(request),源IP地址,源MAC地址,目标IP地址,目标MAC地址(未知,此处为全0);响应包包含包类型(reply),源IP地址,源…

DAY28:回溯算法(三)组合总和Ⅲ+电话号码字母组合

文章目录 216.组合总和Ⅲ思路树形结构 完整版debug测试逻辑错误:没有输出 剪枝操作剪枝版本continue的用法剪枝最后是continue还是return的讨论 17.电话号码的字母组合思路树形结构 伪代码字符串中的字符2转化成int的方法字符串字符与int转换补充字符串与字符 完整版…