提示学习soft prompt浅尝,启发了p-tuing

news2024/12/24 3:53:27

一、前言

在高质量标注数据稀缺的工业界来说,少样本学习或者零样本学习的方法特别受欢迎,后面出现过一些少样本和零样本的方法,例如对比学习和prompt等,主流prompt的工作分为离散型和连续型模板。离散型主要还是插入bert特殊的token为主,连续型则是插入数字token离散型可解释性强于连续型,我这里讲的soft prompt则是连续型的。

大型预训练语言模型的规模不断扩大,在许多自然语言处理 (NLP) 基准测试中取得了最先进的结果。自GPT和BERT开发以来,标准做法一直是在下游任务上微调模型,这涉及调整网络中的每个权重(即模型调整)。然而,随着模型变得越来越大,为每个下游任务存储和提供模型的微调变得不切实际。

一个吸引人的替代方案是在所有下游任务中共享一个单一的冻结预训练语言模型,其中所有权重都是固定的。在一个令人兴奋的发展中,GPT-3令人信服地表明,可以通过“上下文”学习来调节冻结模型以执行不同的任务。使用这种方法,用户通过提示设计为给定任务准备模型,即手工制作带有手头任务描述或示例的文本提示。例如,要为情感分析设置模型,可以附加提示“以下电影评论是正面的还是负面的?” 在输入序列之前,“这部电影太棒了!

二、soft prompt

跨任务共享相同的冻结模型极大地简化了服务并允许有效的混合任务推理,但不幸的是,这是以牺牲任务性能为代价的。文本提示需要人工设计,即使是精心设计的提示与模型调优相比仍然表现不佳。例如,在SuperGLUE基准测试中冻结的 GPT-3 175B 参数模型的性能比使用少 800 倍参数 的微调T5 模型低 5 个点。

在EMNLP 2021上发表的“参数高效提示调整的规模力量”中,我们探索了提示调整,一种使用可调软提示调节冻结模型的更有效方法。就像工程文本提示一样,软提示连接到输入文本。但不是从现有的词汇项目中选择,软提示的“标记”是可学习的向量。这意味着可以在训练数据集上端到端优化软提示。除了消除手动设计的需要之外,这还允许提示从包含数千或数百万个示例的数据集中压缩信息。相比之下,由于模型输入长度的限制,离散文本提示通常限制在 50 个以下示例。

要为给定任务创建软提示,我们首先将提示初始化为固定长度的向量序列(例如,20 个token长)。我们将这些向量附加到每个嵌入输入的开头,并将组合序列输入模型。将模型的预测与目标进行比较以计算损失,并将误差反向传播以计算梯度,但是我们仅将这些梯度更新应用于我们的新可学习向量——保持核心模型冻结。虽然以这种方式学习的软提示不能立即解释,但在直观的层面上,软提示正在从标记的数据集中提取有关如何执行任务的证据,其作用与手动编写的文本提示相同,但​​不需要限于离散的语言。

soft prompt是参数越多效果越好,引自Google发表原文。

三、soft prompt实现

import torch
import torch.nn as nn

class SoftEmbedding(nn.Module):
    def __init__(self, 
                wte: nn.Embedding,
                n_tokens: int = 10, 
                random_range: float = 0.5,
                initialize_from_vocab: bool = True):
        """appends learned embedding to 
        Args:
            wte (nn.Embedding): original transformer word embedding
            n_tokens (int, optional): number of tokens for task. Defaults to 10.
            random_range (float, optional): range to init embedding (if not initialize from vocab). Defaults to 0.5.
            initialize_from_vocab (bool, optional): initalizes from default vocab. Defaults to True.
        """
        super(SoftEmbedding, self).__init__()
        self.wte = wte
        self.n_tokens = n_tokens
        self.learned_embedding = nn.parameter.Parameter(self.initialize_embedding(wte,
                                                                               n_tokens, 
                                                                               random_range, 
                                                                               initialize_from_vocab))
            
    def initialize_embedding(self, 
                             wte: nn.Embedding,
                             n_tokens: int = 10, 
                             random_range: float = 0.5, 
                             initialize_from_vocab: bool = True):
        """initializes learned embedding
        Args:
            same as __init__
        Returns:
            torch.float: initialized using original schemes
        """
        if initialize_from_vocab:
            return self.wte.weight[:n_tokens].clone().detach()
        return torch.FloatTensor(n_tokens, wte.weight.size(1)).uniform_(-random_range, random_range)
            
    def forward(self, tokens):
        """run forward pass
        Args:
            tokens (torch.long): input tokens before encoding
        Returns:
            torch.float: encoding of text concatenated with learned task specifc embedding
        """
        input_embedding = self.wte(tokens[:, self.n_tokens:])
        learned_embedding = self.learned_embedding.repeat(input_embedding.size(0), 1, 1)
        return torch.cat([learned_embedding, input_embedding], 1)

from transformers import AutoConfig, AdamW, AutoTokenizer, AutoModel
import torch
import torch.nn as nn
from soft_embedding import SoftEmbedding

n_tokens = 20
initialize_from_vocab = True
tokenizer = AutoTokenizer.from_pretrained("nezha-base-wwm")
config = AutoConfig.from_pretrained("nezha-base-wwm", num_labels=num_class)
config.output_hidden_states = True  # 需要设置为true才输出
model = AutoModel.from_pretrained(self.model_path, config=config)
s_wte = SoftEmbedding(model.get_input_embeddings(), 
                      n_tokens=n_tokens, 
                      initialize_from_vocab=initialize_from_vocab)
model.set_input_embeddings(s_wte)
inputs = tokenizer("May the force be", return_tensors="pt")

# need to pad attention_mask and input_ids to be full seq_len + n_learned_tokens
# even though it does not matter what you pad input_ids with, it's just to make HF happy
inputs['input_ids'] = torch.cat([torch.full((1,n_tokens), 50256), inputs['input_ids']], 1)
inputs['attention_mask'] = torch.cat([torch.full((1,n_tokens), 1), inputs['attention_mask']], 1)

outputs = model(**inputs)

四、总结

soft prompt比较依赖于模型参数大小,更加适合零样本和小样本,如果用来大数据量下微调模型,效果可能会比普通微调差不多或者更差点。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/652618.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

分享AI绘画的方法

曾经,在一个神奇的编程国度里,住着一个名叫小花的程序员。小花喜欢创造和探索新奇的技术,她有一个惊人的能力:她能够根据文字生成相应的图片。这项技术让她成为了这个国度里的传奇人物。人们纷纷向她寻求帮助,希望能够…

flutter:网络请求、json数据转为Model

参考 老孟 flutter: 网络请求-dio http http 是一个可组合,基于Future的库,用于HTTP请求。该软件包包含高级功能和类,可轻松使用HTTP资源。它是多平台的,并且支持移动设备,台式机和浏览器。此软件包为官…

STM32F1x固件库函数学习笔记(一)

文章目录 一、基础知识1、什么是STM322、STM32诞生背景3、STM32分类4、STM32F1X系列命名规则5、STM32F103C8T6最小系统 二、STM32固件库1、初始固件库(1)51单片机的寄存器(2)STC8A通过库函数方式实现LED闪烁(3&#xf…

Spark入门(二)

2.3 Standalone模式 Standalone模式是Spark自带的资源调度引擎,构建一个由Master Worker构成的Spark集群,Spark运行在集群中。 这个Standalone区别于Hadoop的。这里的Standalone是指只用Spark来搭建一个集群,不需要借助其他框架。 2.3.1集…

充能书单|618,买什么都不如买知识!

前言 “IT有得聊”是机械工业出版社旗下IT专业资讯和服务平台,致力于帮助读者在广义的IT领域里,掌握更专业、更实用的知识与技能,快速提升职场竞争力。 点击蓝色微信名可快速关注我们。 一年一度的618又到啦!今年的618就不要乱买…

【Linux】MySQL数据库 (二)

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 MySQL数据库 数据表高级操作克隆表,将数据表的数据记录生成到新的表中清空表,删除表内的所有数据创建临时表创建外键约束,保证数据的完整性…

【计算机网络】第一章 概述(下)

文章目录 第一章 概述1.5 计算机网络的性能指标1.5.1 速率1.5.2 带宽1.5.3 吞吐量1.5.4 时延 1.6 计算机网络体系结构1.6.1 常见的体系结构1.6.2 分层的必要性1.6.4 体系结构中的专用术语 1.8 习题 第一章 概述 1.5 计算机网络的性能指标 常用的 计算机网络 的性能指标有以下 …

ECC算法学习(一)算法公式

ECC 一、ECC简介优缺点运用 二、算法理论基础1. 椭圆曲线的加法2. 椭圆曲线的二倍运算3. 同余运算4. 有限域5. 乘法逆元 三、算法公式1、有限域的负元2、有限域的加法, P Q P Q PQ3. 斜率计算(PQ即要计算P点切线,需要求导)4. 椭…

chatgpt赋能python:PYTHON如何进行累乘操作?

PYTHON如何进行累乘操作? 在PYTHON编程中,累乘操作是指不断地将一个给定数字序列中的数字相乘的过程。这个操作在数学中也被称为阶乘,通常用符号“!”来表示。 在PYTHON中,进行累乘操作的方法主要有两种:使用循环实现…

Java实现TestNg+ExtentReport实现接口测试,并生成测试报告

一 在pom.xml文件中引入TestNg以及ExtentReport包 <dependencies> <!--testNg引入--> <dependency> <groupId>org.testng</groupId> <artifactId>testng</artifactId> <version>6.9.10</version> </de…

学习css样式的第二章

1.CSS 布局 - display 属性 display 属性是用于控制布局的最重要的 CSS 属性。 display 属性 display 属性规定是否/如何显示元素。 每个 HTML 元素都有一个默认的 display 值&#xff0c;具体取决于它的元素类型。大多数元素的默认 display 值为 block 或 inline 块级元素…

chatgpt赋能python:Python编程:如何粘贴代码

Python编程&#xff1a;如何粘贴代码 在Python编程过程中&#xff0c;粘贴代码是一个非常普遍的操作。不幸的是&#xff0c;许多初学者并不知道如何正确地粘贴代码&#xff0c;这可能会导致一些常见的错误和问题。本文将介绍如何正确地粘贴代码以及一些常见的问题和解决方案。…

Wise 的平台工程 KPI 探索之旅

作者&#xff5c;Lambros Charissis 翻译&#xff5c;Seal软件 链接&#xff5c;https://medium.com/wise-engineering/platform-engineering-kpis-6a3215f0ee14 平台即产品&#xff08;PaaP&#xff09;已经成为软件企业构建内部平台的一种流行方式。在众多软件公司争夺市场份…

地球物理专业毕业生毕业后能干高性能计算工程师吗?

很多高校都开设有地球物理专业&#xff0c;但是很多身为地球物理专业的毕业生&#xff0c;很多同学却不清楚以后能做什么工作&#xff0c;做什么工作有前景&#xff0c;十分迷茫。在这里&#xff0c;我们有很多从事高性能计算领域的前地球物理专业学长现身说法——地球物理专业…

Qt弱加密漏洞分析

0x00 漏洞背景 Qt是一个跨平台的C应用程序开发框架&#xff0c;用于创建图形用户界面&#xff08;GUI&#xff09;应用程序、命令行工具、嵌入式系统和网络应用等各种类型的应用。 Qt框架包含的Qt Network&#xff08;网络模块&#xff09;&#xff0c;提供了QNetworkAccessM…

Vue中如何进行游戏开发与游戏引擎集成?

Vue中如何进行游戏开发与游戏引擎集成&#xff1f; Vue.js是一款流行的JavaScript框架&#xff0c;它的MVVM模式和组件化开发思想非常适合构建Web应用程序。但是&#xff0c;如果我们想要开发Web游戏&#xff0c;Vue.js并不是最合适的选择。在本文中&#xff0c;我们将介绍如何…

mycat分库分表中间件介绍,有案例

目录 MyCat分库分表概述水平拆分和垂直拆分安装JDK安装MyCat安装 MyCat案例1、创建数据库2、分片配置&#xff08;schema.xml&#xff09;3、分片配置&#xff08;server.xml&#xff09;4、启动服务5、查看日志&#xff0c;看是否启动成功6、登录MyCat7、查看数据库和表8、创建…

SpringBoot自动配置原理简单分析

说明&#xff1a;在SpringBoot项目中&#xff0c;我们添加了许许多多的注解&#xff0c;这些注解提高了开发效率。这是因为SpringBoot在项目启动时&#xff0c;帮我们自动装配了大量的Bean对象&#xff0c;可以通过分析源码查看自动装配的大致原理。 第一步&#xff1a;Spring…

python:并发编程(十)

前言 本文将和大家一起探讨python的多协程并发编程&#xff08;上篇&#xff09;&#xff0c;使用内置基本库asyncio来实现并发&#xff0c;先通过官方来简单使用这个模块。先打好基础&#xff0c;能够有个基本的用法与认知&#xff0c;后续文章&#xff0c;我们再进行详细使用…

【程序人生-Hello‘s P2P】哈尔滨工业大学深入理解计算机系统大作业

计算机系统 大作业 题 目 程序人生-Hello’s P2P 专 业 xxxx 学 号 2021xxxx 班 级 210xxxx 学 生 xx 指 导 教 师 xxx 计算机科学与技术学院 2023年5月 摘 要 HelloWorld是每个程序员接触的第一个程序&#xff0c;表面上平平无奇的它背后却是由操作系统许多设计精巧的机制支撑…