【简单、高效、性能好】SetFit:无需Prompts的高效小样本学习

news2025/1/11 11:45:07

重磅推荐专栏: 《Transformers自然语言处理系列教程》
手把手带你深入实践Transformers,轻松构建属于自己的NLP智能应用!

1. 概要

使用预训练语言模型的小样本学习(处理只有少量标签或没有标签的数据)已成为比较普遍的解决方案。
SetFit:一种用于对 Sentence Transformers 进行少量微调的高效框架。SetFit 用很少的标记数据实现了高精度——例如,在客户评论 (CR) 情绪数据集上每个类只有 8 个标记样本,SetFit 在 3k 个样本的完整训练集上与微调 RoBERTa Large 相比,如图1-1所示,具有竞争力表现:
图1-1,与标准微调相比,SetFit 的样本效率和抗噪能力要高得多
与其他小样本学习方法相比,SetFit 有几个独特的特点:

  • 没有提示(prompts )或语言器(verbalisers):当前的小样本微调技术需要手工制作的提示(prompts )或语言器(verbalisers)将样本转换为适合底层语言模型的格式。SetFit 通过直接从少量带标签的文本示例生成丰富的embeddings 来完全免除prompts 。

  • 训练速度快:SetFit 不需要像 T0 或 GPT-3 这样的大型模型来实现高精度。因此,训练和运行推理的速度通常快一个数量级(或更多)。

  • 多语言支持:SetFit 可以与 Hub 上的任何 Sentence Transformer 一起使用,这意味着你可以通过简单地微调多语言checkpoint来对多种语言的文本进行分类。

  • 论文: https://arxiv.org/pdf/2209.11055.pdf

  • 代码:https://github.com/huggingface/setfit

2. 原理

SetFit原理比较简单,它设计考虑了效率和简单性。SetFit 首先在少量标记示例(通常每个类 8 或 16 个)上微调 Sentence Transformer 模型。接下来是在微调的 Sentence Transformer 生成的embeddings上训练分类器头。SetFit 利用 Sentence Transformers 的能力基于成对的句子生成密集embeddings 。如图2-1所示:
图2-1,SetFit的两阶段训练流程
图2-2,句子对生成伪代码

  • 在初始微调阶段,它通过对比训练利用有限的标记输入数据,其中正负对由类内和类外选择创建,如图2-2所示。然后,Sentence Transformer 模型对这些对(或三元组)进行训练,并为每个样本生成密集向量。
  • 在第二步中,分类头使用各自的类标签对编码embeddings进行训练。在推理时,未见过的样本通过微调的 Sentence Transformer,生成一个embedding ,当将其送到分类头时,输出一个类标签预测结果。

只需将基本的 Sentence Transformer 模型切换为多语言模型,SetFit 就可以在多语言环境中无缝运行。

3. 实验

3.1 效果表现

虽然基于比现有的少样本方法小得多的模型,但 SetFit 在各种基准测试中的表现与sota的少样本方法相当或更好。如图3-1所示,在RAFT(一个 few-shot 分类基准)上,具有 3.55 亿个参数的 SetFit Roberta 优于 PET 和 GPT-3。它仅仅在人类平均表现和 110 亿个参数 T-few(这个模型的大小是 SetFit Roberta 的 30 倍) 水平之下。SetFit 在 11 项 RAFT 任务中的 7 项上也优于人类基线
图3-1,RAFT 排行榜上的突出方法(截至 2022 年 9 月)
在其他数据集上,SetFit 在各种任务中表现出稳健性。如下图3-2所示,每个类只有 8 个示例,它基本上优于 PERFECT、ADAPET 和微调的 vanilla transformer。SetFit 也取得了与 T-Few 3B 相当的结果,尽管它无需提示且体积小 27 倍
图3-2,在 3 个分类数据集上将 Setfit 性能与其他方法进行比较

3.2 训练和推理速度

由于 SetFit 使用相对较小的模型实现了高精度,因此它的训练速度非常快,而且成本要低得多。例如,使用 8 个标记示例在 NVIDIA V100 上训练 SetFit 仅需 30 秒,成本为 0.025 美元。相比之下,训练 T-Few 3B 需要 NVIDIA A100,耗时 11 分钟,同一实验的成本约为 0.7 美元——高出 28 倍。事实上,SetFit 可以像 Google Colab 上的那样在单个 GPU 上运行,你甚至可以在几分钟内在 CPU 上训练 SetFit!如图3-3所示,SetFit带来了提速,模型性能却与T-Few 3B相当。预测和蒸馏 SetFit 模型也可以获得类似的收益,可以带来 123 倍的加速!
图3-3,比较 T-Few 3B 和 SetFit (MPNet) 的训练成本和平均性能,每个类有 8 个标记样本

4. 实践:零样本文本分类

SetFit还可以做零样本文本分类。我们需要做的第一件事是创建一个合成样本的虚拟数据集。我们可以通过将 add_templated_examples() 函数来完成此操作。此函数需要一些主要内容:

  • 用于分类的候选标签列表。 我们将在此处使用参考数据集中的标签。
  • 用于生成示例的模板。 默认情况下,它是“This sentence is {}”,其中{}将由候选标签名称之一填充
  • 样本量 N,这将为每个类创建 N 个合成示例。 作者发现 N=8 通常效果最好。
dataset_id = "emotion"
model_id = "sentence-transformers/paraphrase-mpnet-base-v2"
from datasets import load_dataset
reference_dataset = load_dataset(dataset_id)

# 从“label”列中提取 ClassLabel 特征
label_features = reference_dataset["train"].features["label"]
# 用于分类的标签名称
candidate_labels = label_features.names
candidate_labels
['sadness', 'joy', 'love', 'anger', 'fear', 'surprise']
from datasets import Dataset
from setfit import add_templated_examples

# 用合成样本填充的虚拟数据集
dummy_dataset = Dataset.from_dict({})
train_dataset = add_templated_examples(dummy_dataset, candidate_labels=candidate_labels, sample_size=8)
train_dataset

由于我们的数据集有 6 个类别,我们选择的样本大小为 8,因此我们的合成数据集包含 6×8=48 个样本。

Dataset({
    features: ['text', 'label'],
    num_rows: 48
})

我们看几个例子:

train_dataset.shuffle()[:3]
{'text': ['This sentence is love',
  'This sentence is fear',
  'This sentence is joy'],
 'label': [2, 4, 1]}

用这样虚拟数据集来微调模型,在预测看看效果:

from setfit import SetFitModel

model = SetFitModel.from_pretrained(model_id)

from setfit import SetFitTrainer

trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=reference_dataset["test"]
)

trainer.train()
zeroshot_metrics = trainer.evaluate()
zeroshot_metrics
{'accuracy': 0.5345}

我们在尝试一下用 Hugging Face 的 zero-shot-classification:

from transformers import pipeline

pipe = pipeline("zero-shot-classification", device=0)

zeroshot_preds = pipe(reference_dataset["test"]["text"], batch_size=16, candidate_labels=candidate_labels)

zero-shot-classification pipeline 默认用的是 facebook/bart-large-mnli。注意,该方法 生成预测结果所需的时间比 SetFit 长将近 5 倍! 好的,那么它的性能如何?

preds = [label_features.str2int(pred["labels"][0]) for pred in zeroshot_preds]

import evaluate

metric = evaluate.load("accuracy")
transformers_metrics = metric.compute(predictions=preds, references=reference_dataset["test"]["label"])
transformers_metrics

与 SetFit 相比,这种方法的性能要差得多:

{'accuracy': 0.3765}

看来 SetFit 真的是——即简单,又高效,还性能好 !666666666666666…

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/42779.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

(附源码)计算机毕业设计Java大学生学科竞赛报名管理系统

项目运行 环境配置: Jdk1.8 Tomcat8.5 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: Springboot mybatis Maven Vue 等等组成,B/…

【Java语言】Java类与对象的详细教程,一看就会

Java类与对象 文章目录Java类与对象1. 类与对象的初步认知2. 类和类的实例化3. 类的成员3.1 字段/属性/成员变量3.1.1认识 null3.1.2字段就地初始化3.2 方法 (method)3.3 static 关键字3.4方法调用易错区分4. 封装4.1 private实现封装4.2 getter和setter方法5.构造方法5.1 基本…

【密码学基础】Oblivious Transfer(不经意传输)

头一次开始学密码学相关的东西,未来的主要研究方向包括了隐私计算,即隐私保护下的机器学习算法。 0 举个实际的例子 引用博客OT(Oblivious Transfer,不经意传输)协议详解提到的例子,我们这里考虑1-out-of-…

美团应届生面试第一问:Object o = new Object()占用多少字节?

文章目录工具查看内存分配Java内存模型访问对象方式GC为什么Survivor要分为两个区域(S0和S1)?Survivor 为什么不分更多块呢?对象的生命周期小知识工具查看内存分配 Object o new Object();占用多少字节,我们借助open…

重要公告 | 论坛域名更换,请务必及时收藏

论坛的小伙伴们: 为进一步规范网站域名,自2022年11月16日起,“西门子低代码开发者论坛”的域名由:https://forum.mendix.tencent-cloud.com/,正式变更为:https://marketplace.siemens.com.cn/low-code-com…

Kamiya丨Kamiya艾美捷人和动物LBP ELISA说明书

Kamiya艾美捷人和动物LBP ELISA预期用途: 人和动物LBP ELISA已被开发用于定量测定天然和血清,血浆和培养基中的重组人LBP。也适用于牛,猪,兔和狗LBP。仅供研究使用。不用于诊断程序。 Kamiya艾美捷人和动物LBP ELISA原理&#xf…

地理计算 | 计算两个坐标点射线的交点(前方交会)

1 前言 前方交会--- 又称为测角交会,是指从相邻两个已知点向待定点观测两个水平角,用以计算待定点的坐标。 如图所示,点 A、B 的坐标已知。 通过观测角 A 和角 B 求出点 P 坐标的定位方法被称之为“角度前方交会”; 通过观测方…

汽车租赁系统毕业设计,汽车租赁管理系统设计与实现,毕业设计论文毕设作品参考

功能清单 【后台管理员功能】 广告管理:设置小程序首页轮播图广告和链接 留言列表:所有用户留言信息列表,支持删除 会员列表:查看所有注册会员信息,支持删除 资讯分类:录入、修改、查看、删除资讯分类 录入…

代码随想录训练营day46, 单词拆分和多重背包

今天就这一道题, 但还是有难度的 单词就是物品, 字符串s就是背包, 单词能否组成字符串s, 就是问物品能不能把背包装满 确定dp数组含义: 字符串长度为i的话, dp[i]为true, 表示可以拆分, j是分割指针确定递推公式: 如果确定dp[j]是true, 且[j , i]这个区间的子串出现在字典里,…

案例-Shell定时采集数据到HDFS

1. 准备工作 创建日志文件存放的目录 /export/data/logs/log,执行命令:mkdir -p /export/data/logs/log 创建待上传文件存放的目录/export/data/logs/toupload,执行命令:mkdir -p /export/data/logs/toupload 查看创建的目录树结…

FSC在全球范围内增强品牌相关度,促进公众理解

【FSC在全球范围内增强品牌相关度,促进公众理解】 FSC品牌标识 “森林与共,生生不息”将逐渐精简,同时覆盖更多语种。 加深消费者对FSC的理解 近年来,FSC品牌认知度不断提高,超过半数的全球消费者认可并信任“小树”标…

为什么劝你要学习Golang以及GO语言(Go语言知识普及)

Go语言 一、 Go语言的由来 Go语言亦叫Golang语言,是由谷歌Goggle公司推出。 传统的语言比如c,大家花费太多时间来学习如何使用这门语言,而不是如何更好的表达写作者的思想,同时编译 花费的时间实在太长,对于编写-编译…

C语言只推荐这1本宝藏书,你读过吗?

入门的大家随便搜搜学起来都不会出错,进阶的推荐1本豆瓣评分9.1,这本经典之作真正地让人搞懂了烦人的指针。 指针为什么如此重要?C语言圈内有一句经典的自嘲:C语言就只有指针可以用了。如果你干掉struct、干掉union、干掉数组、甚…

html在线阅读小说网页制作模板 小说书籍网页设计 大学生静态HTML网页源码 dreamweaver网页作业 简单网页课程成品

🎉精彩专栏推荐 💭文末获取联系 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 💂 作者主页: 【主页——🚀获取更多优质源码】 🎓 web前端期末大作业: 【📚毕设项目精品实战案例 (10…

管道通信: 有名管道 无名管道,行业大牛通通教会你

管道是一种最古老也是最基本的系统IPC形式,管道就像现实中的水管,水就像数据,它是消息传递的一种特殊方式,管道机制必须提供三方面的协调能力:互斥、同步和确定对方的存在。在Linux中是一种使用非常频繁的通信机制。从…

链表剖析及自己手撸“单链表“实现基本操作(初始化、增、删、改等)

一. 基础 1. 前言 链式存储结构,又叫链表,逻辑上相邻,但是物理位置可以不相邻,因此对链表进行插入和删除时不需要移动数据元素,但是存取数据的效率却很低,分为三类: (1).单(向)链表&#xff1…

【iconfont图标】vue引入并使用阿里巴巴iconfont图标流程

前言 为什么要使用阿里图标库: 图标现在是很多地方都会用到的 一般我使用的时候都是直接在ui库中比如elementul自带的一些 有时候哪怕是感觉图标不是非常适合也是用的elementul图标,主要原因是懒 因为能直接复制的,就懒得再去阿里图标库下载…

如何让Java项目兼容更多的客户端设备(一)

如何让Java项目兼容更多的客户端设备(一) 引入 HTTP访问是无状态的,(服务器不知道是不是你访问的)所以我们不知道每次登录的是谁 如果想实现每次登录不用重复登录,最简单的就是让浏览器记住用户名和密码…

球面距离计算方式(杭州到各城市的球面距离计算球面距离)

1)杭州到各城市的球面距离 1、数据来源:自主计算 2、时间跨度:至今 3、区域范围:368个城市 4、指标说明:利用城市经纬度,计算球面距离 部分数据如下: (2)计算两个点之…

Sentinel配置持久化到Nacos实现流控熔断

控制台 jar 下载:github.com/alibaba/Sen… 启动参数 # 将控制台自身接入到sentinel nohup java -jar -Dproject.namesentinel-dashboard -Dcsp.sentinel.dashboard.serverlocalhost:8181 sentinel-dashboard-1.8.5.jar --server.port8181 &> sentinel.log …